diff --git a/Cargo.lock b/Cargo.lock index 002169d90..ea8775ddd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,7 @@ dependencies = [ "strum 0.25.0", "strum_macros 0.25.3", "sumcheck", + "thread_local", "tracing", "tracing-flame", "tracing-subscriber", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index bd13aab86..eb9d256f4 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -28,6 +28,7 @@ tracing-flame = "0.2.0" tracing = "0.1.40" rand = "0.8" +thread_local = "1.1.8" [dev-dependencies] pprof = { version = "0.13", features = ["flamegraph"]} diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 9385c0294..0b4a672a8 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -246,7 +246,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } - /// lookup a < b as usigned byte + /// lookup a < b as unsigned byte pub(crate) fn lookup_ltu_limb8( &mut self, res: Expression, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 7b44d052c..7a37fc39a 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -4,7 +4,11 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; -use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + witness::{LkMultiplicity, RowMajorMatrix}, +}; pub mod riscv; @@ -18,6 +22,7 @@ pub trait Instruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, step: StepRecord, ) -> Result<(), ZKVMError>; @@ -25,15 +30,19 @@ pub trait Instruction { config: &Self::InstructionConfig, num_witin: usize, steps: Vec, - ) -> Result, ZKVMError> { + ) -> Result<(RowMajorMatrix, LkMultiplicity), ZKVMError> { + let lk_multiplicity = LkMultiplicity::default(); let mut raw_witin = RowMajorMatrix::::new(steps.len(), num_witin); let raw_witin_iter = raw_witin.par_iter_mut(); raw_witin_iter .zip_eq(steps.into_par_iter()) - .map(|(instance, step)| Self::assign_instance(config, instance, step)) + .map(|(instance, step)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + Self::assign_instance(config, instance, &mut lk_multiplicity, step) + }) .collect::>()?; - Ok(raw_witin) + Ok((raw_witin, lk_multiplicity)) } } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index 47bbd49d8..84cf4f36c 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -17,6 +17,7 @@ use crate::{ instructions::Instruction, set_val, uint::UIntValue, + witness::LkMultiplicity, }; use core::mem::MaybeUninit; @@ -151,13 +152,14 @@ impl Instruction for AddInstruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, step: StepRecord, ) -> Result<(), ZKVMError> { // TODO use fields from step set_val!(instance, config.pc, 1); set_val!(instance, config.ts, 2); - let addend_0 = UIntValue::new(step.rs1().unwrap().value); - let addend_1 = UIntValue::new(step.rs2().unwrap().value); + let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value); + let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value); config .prev_rd_value .assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect()); @@ -167,7 +169,7 @@ impl Instruction for AddInstruction { config .addend_1 .assign_limbs(instance, addend_1.u16_fields()); - let carries = addend_0.add_u16_carries(&addend_1); + let (_, carries) = addend_0.add(&addend_1, lk_multiplicity, true); config.outcome.assign_carries( instance, carries @@ -199,6 +201,7 @@ impl Instruction for SubInstruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [MaybeUninit], + _lk_multiplicity: &mut LkMultiplicity, _step: StepRecord, ) -> Result<(), ZKVMError> { // TODO use field from step @@ -263,7 +266,7 @@ mod test { .unwrap() .unwrap(); - let raw_witin = AddInstruction::assign_instances( + let (raw_witin, _) = AddInstruction::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord { @@ -310,7 +313,7 @@ mod test { .unwrap() .unwrap(); - let raw_witin = AddInstruction::assign_instances( + let (raw_witin, _) = AddInstruction::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord { diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index fc0181d4b..007126d4d 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -15,6 +15,7 @@ use crate::{ }, set_val, utils::{i64_to_base, limb_u8_to_u16}, + witness::LkMultiplicity, }; use super::{ @@ -222,6 +223,7 @@ impl Instruction for BltInstruction { fn assign_instance( config: &Self::InstructionConfig, instance: &mut [std::mem::MaybeUninit], + _lk_multiplicity: &mut LkMultiplicity, _step: ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { // take input from _step @@ -250,7 +252,7 @@ mod test { let num_wits = circuit_builder.cs.num_witin as usize; // generate mock witness let num_instances = 1 << 4; - let raw_witin = BltInstruction::assign_instances( + let (raw_witin, _) = BltInstruction::assign_instances( &config, num_wits, vec![StepRecord::default(); num_instances], diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 3e9612f1a..811132e63 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,5 +1,6 @@ #![feature(box_patterns)] #![feature(stmt_expr_attributes)] +#![feature(variant_count)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 8a27fdb9a..44927014d 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -7,6 +7,7 @@ use crate::{ error::{UtilError, ZKVMError}, expression::{Expression, ToExpr, WitIn}, utils::add_one_to_big_num, + witness::LkMultiplicity, }; use ark_std::iterable::Iterable; use constants::BYTE_BIT_WIDTH; @@ -476,13 +477,29 @@ impl + Copy> UIntValue { mem::size_of::() / u16_bytes }; - pub fn new(val: T) -> Self { + #[allow(dead_code)] + pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self { + let uint = UIntValue:: { + val, + limbs: Self::split_to_u16(val), + }; + Self::assert_u16(&uint.limbs, lkm); + uint + } + + pub fn new_unchecked(val: T) -> Self { UIntValue:: { val, limbs: Self::split_to_u16(val), } } + fn assert_u16(v: &[u16], lkm: &mut LkMultiplicity) { + v.iter().for_each(|v| { + lkm.assert_ux::<16>(*v as u64); + }) + } + fn split_to_u16(value: T) -> Vec { let value: u64 = value.into(); // Convert to u64 for generality (0..Self::LIMBS) @@ -502,20 +519,35 @@ impl + Copy> UIntValue { self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec() } - pub fn add_u16_carries(&self, rhs: &Self) -> Vec { - self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold( + pub fn add( + &self, + rhs: &Self, + lkm: &mut LkMultiplicity, + with_overflow: bool, + ) -> (Vec, Vec) { + let res = self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold( vec![], |mut acc, (a_limb, b_limb)| { let (a, b) = a_limb.overflowing_add(*b_limb); - if let Some(prev_carry) = acc.last() { - let (_, d) = a.overflowing_add(*prev_carry as u16); - acc.push(b || d); + if let Some((_, prev_carry)) = acc.last() { + let (e, d) = a.overflowing_add(*prev_carry as u16); + acc.push((e, b || d)); } else { - acc.push(b); + acc.push((a, b)); } + // range check + if let Some((limb, _)) = acc.last() { + lkm.assert_ux::<16>(*limb as u64); + }; acc }, - ) + ); + let (limbs, mut carries): (Vec, Vec) = res.into_iter().unzip(); + if !with_overflow { + carries.resize(carries.len() - 1, false); + } + carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64)); + (limbs, carries) } } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 9b44586da..718145ddc 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -1,6 +1,10 @@ use std::{ + array, + cell::RefCell, + collections::HashMap, mem::{self, MaybeUninit}, slice::ChunksMut, + sync::Arc, }; use multilinear_extensions::util::create_uninit_vec; @@ -8,6 +12,9 @@ use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, slice::ParallelSliceMut, }; +use thread_local::ThreadLocal; + +use crate::structs::ROMType; #[macro_export] macro_rules! set_val { @@ -51,3 +58,102 @@ impl RowMajorMatrix { .collect() } } + +/// A lock-free thread safe struct to count logup multiplicity for each ROM type +/// Lock-free by thread-local such that each thread will only have its local copy +/// struct is cloneable, for internallly it use Arc so the clone will be low cost +#[derive(Clone, Default)] +#[allow(clippy::type_complexity)] +pub struct LkMultiplicity { + multiplicity: Arc; mem::variant_count::()]>>>, +} + +#[allow(dead_code)] +impl LkMultiplicity { + /// assert within range + #[inline(always)] + pub fn assert_ux(&mut self, v: u64) { + match C { + 16 => self.assert_u16(v), + 8 => self.assert_byte(v), + 5 => self.assert_u5(v), + _ => panic!("Unsupported bit range"), + } + } + + fn assert_u5(&mut self, v: u64) { + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::U5 as usize] + .entry(v) + .or_default()) += 1; + } + + fn assert_u16(&mut self, v: u64) { + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::U16 as usize] + .entry(v) + .or_default()) += 1; + } + + fn assert_byte(&mut self, v: u64) { + let v = v * (1 << u8::BITS); + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::U16 as usize] + .entry(v) + .or_default()) += 1; + } + + /// lookup a < b as unsigned byte + pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) { + let key = a.wrapping_mul(256) + b; + let multiplicity = self + .multiplicity + .get_or(|| RefCell::new(array::from_fn(|_| HashMap::new()))); + (*multiplicity.borrow_mut()[ROMType::Ltu as usize] + .entry(key) + .or_default()) += 1; + } + + /// merge result from multiple thread local to single result + fn into_finalize_result(self) -> [HashMap; mem::variant_count::()] { + Arc::try_unwrap(self.multiplicity) + .unwrap() + .into_iter() + .fold(array::from_fn(|_| HashMap::new()), |mut x, y| { + x.iter_mut().zip(y.borrow().iter()).for_each(|(m1, m2)| { + for (key, value) in m2 { + *m1.entry(*key).or_insert(0) += value; + } + }); + x + }) + } +} + +#[cfg(test)] +mod tests { + use std::thread; + + use crate::{structs::ROMType, witness::LkMultiplicity}; + + #[test] + fn test_lk_multiplicity_threads() { + // TODO figure out a way to verify thread_local hit/miss in unittest env + let lkm = LkMultiplicity::default(); + let thread_count = 20; + // each thread calling assert_byte once + for _ in 0..thread_count { + let mut lkm = lkm.clone(); + thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap(); + } + let res = lkm.into_finalize_result(); + // check multiplicity counts of assert_byte + assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count); + } +}