Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

logup multiplicity in witness assignment #198

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tracing-flame = "0.2.0"
tracing = "0.1.40"

rand = "0.8"
thread_local = "1.1.8"

[dev-dependencies]
pprof = { version = "0.13", features = ["flamegraph"]}
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

/// lookup a < b as usigned byte
/// lookup a < b as unsigned byte
pub(crate) fn lookup_ltu_limb8(
&mut self,
res: Expression<E>,
Expand Down
17 changes: 13 additions & 4 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix};
use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
witness::{LkMultiplicity, RowMajorMatrix},
};

pub mod riscv;

Expand All @@ -18,22 +22,27 @@ pub trait Instruction<E: ExtensionField> {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: StepRecord,
) -> Result<(), ZKVMError>;

fn assign_instances(
config: &Self::InstructionConfig,
num_witin: usize,
steps: Vec<StepRecord>,
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
) -> Result<(RowMajorMatrix<E::BaseField>, LkMultiplicity), ZKVMError> {
let lk_multiplicity = LkMultiplicity::default();
let mut raw_witin = RowMajorMatrix::<E::BaseField>::new(steps.len(), num_witin);
let raw_witin_iter = raw_witin.par_iter_mut();

raw_witin_iter
.zip_eq(steps.into_par_iter())
.map(|(instance, step)| Self::assign_instance(config, instance, step))
.map(|(instance, step)| {
let mut lk_multiplicity = lk_multiplicity.clone();
Self::assign_instance(config, instance, &mut lk_multiplicity, step)
})
kunxian-xia marked this conversation as resolved.
Show resolved Hide resolved
.collect::<Result<(), ZKVMError>>()?;

Ok(raw_witin)
Ok((raw_witin, lk_multiplicity))
}
}
13 changes: 8 additions & 5 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
instructions::Instruction,
set_val,
uint::UIntValue,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

Expand Down Expand Up @@ -151,13 +152,14 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: StepRecord,
) -> Result<(), ZKVMError> {
// TODO use fields from step
set_val!(instance, config.pc, 1);
set_val!(instance, config.ts, 2);
let addend_0 = UIntValue::new(step.rs1().unwrap().value);
let addend_1 = UIntValue::new(step.rs2().unwrap().value);
let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value);
let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value);
config
.prev_rd_value
.assign_limbs(instance, [0, 0].iter().map(E::BaseField::from).collect());
Expand All @@ -167,7 +169,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
config
.addend_1
.assign_limbs(instance, addend_1.u16_fields());
let carries = addend_0.add_u16_carries(&addend_1);
let (_, carries) = addend_0.add(&addend_1, lk_multiplicity, true);
config.outcome.assign_carries(
instance,
carries
Expand Down Expand Up @@ -199,6 +201,7 @@ impl<E: ExtensionField> Instruction<E> for SubInstruction {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
_lk_multiplicity: &mut LkMultiplicity,
_step: StepRecord,
) -> Result<(), ZKVMError> {
// TODO use field from step
Expand Down Expand Up @@ -263,7 +266,7 @@ mod test {
.unwrap()
.unwrap();

let raw_witin = AddInstruction::assign_instances(
let (raw_witin, _) = AddInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord {
Expand Down Expand Up @@ -310,7 +313,7 @@ mod test {
.unwrap()
.unwrap();

let raw_witin = AddInstruction::assign_instances(
let (raw_witin, _) = AddInstruction::assign_instances(
&config,
cb.cs.num_witin as usize,
vec![StepRecord {
Expand Down
4 changes: 3 additions & 1 deletion ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
},
set_val,
utils::{i64_to_base, limb_u8_to_u16},
witness::LkMultiplicity,
};

use super::{
Expand Down Expand Up @@ -222,6 +223,7 @@ impl<E: ExtensionField> Instruction<E> for BltInstruction {
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<E::BaseField>],
_lk_multiplicity: &mut LkMultiplicity,
_step: ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
// take input from _step
Expand Down Expand Up @@ -250,7 +252,7 @@ mod test {
let num_wits = circuit_builder.cs.num_witin as usize;
// generate mock witness
let num_instances = 1 << 4;
let raw_witin = BltInstruction::assign_instances(
let (raw_witin, _) = BltInstruction::assign_instances(
&config,
num_wits,
vec![StepRecord::default(); num_instances],
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(box_patterns)]
#![feature(stmt_expr_attributes)]
#![feature(variant_count)]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is for mem::variant_count::()] to count entries of ROMType


pub mod error;
pub mod instructions;
Expand Down
48 changes: 40 additions & 8 deletions ceno_zkvm/src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
error::{UtilError, ZKVMError},
expression::{Expression, ToExpr, WitIn},
utils::add_one_to_big_num,
witness::LkMultiplicity,
};
use ark_std::iterable::Iterable;
use constants::BYTE_BIT_WIDTH;
Expand Down Expand Up @@ -476,13 +477,29 @@ impl<T: Into<u64> + Copy> UIntValue<T> {
mem::size_of::<T>() / u16_bytes
};

pub fn new(val: T) -> Self {
#[allow(dead_code)]
pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self {
let uint = UIntValue::<T> {
val,
limbs: Self::split_to_u16(val),
};
Self::assert_u16(&uint.limbs, lkm);
uint
}

pub fn new_unchecked(val: T) -> Self {
UIntValue::<T> {
val,
limbs: Self::split_to_u16(val),
}
}

fn assert_u16(v: &[u16], lkm: &mut LkMultiplicity) {
v.iter().for_each(|v| {
lkm.assert_ux::<16>(*v as u64);
})
}

fn split_to_u16(value: T) -> Vec<u16> {
let value: u64 = value.into(); // Convert to u64 for generality
(0..Self::LIMBS)
Expand All @@ -502,20 +519,35 @@ impl<T: Into<u64> + Copy> UIntValue<T> {
self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec()
}

pub fn add_u16_carries(&self, rhs: &Self) -> Vec<bool> {
self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold(
pub fn add(
&self,
rhs: &Self,
lkm: &mut LkMultiplicity,
with_overflow: bool,
) -> (Vec<u16>, Vec<bool>) {
let res = self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold(
vec![],
|mut acc, (a_limb, b_limb)| {
let (a, b) = a_limb.overflowing_add(*b_limb);
if let Some(prev_carry) = acc.last() {
let (_, d) = a.overflowing_add(*prev_carry as u16);
acc.push(b || d);
if let Some((_, prev_carry)) = acc.last() {
let (e, d) = a.overflowing_add(*prev_carry as u16);
acc.push((e, b || d));
} else {
acc.push(b);
acc.push((a, b));
}
// range check
if let Some((limb, _)) = acc.last() {
lkm.assert_ux::<16>(*limb as u64);
};
acc
},
)
);
let (limbs, mut carries): (Vec<u16>, Vec<bool>) = res.into_iter().unzip();
if !with_overflow {
carries.resize(carries.len() - 1, false);
}
carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64));
(limbs, carries)
}
}

Expand Down
106 changes: 106 additions & 0 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
use std::{
array,
cell::RefCell,
collections::HashMap,
mem::{self, MaybeUninit},
slice::ChunksMut,
sync::Arc,
};

use multilinear_extensions::util::create_uninit_vec;
use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
slice::ParallelSliceMut,
};
use thread_local::ThreadLocal;

use crate::structs::ROMType;

#[macro_export]
macro_rules! set_val {
Expand Down Expand Up @@ -51,3 +58,102 @@ impl<T: Sized + Sync + Clone + Send> RowMajorMatrix<T> {
.collect()
}
}

/// A lock-free thread safe struct to count logup multiplicity for each ROM type
/// Lock-free by thread-local such that each thread will only have its local copy
/// struct is cloneable, for internallly it use Arc so the clone will be low cost
#[derive(Clone, Default)]
#[allow(clippy::type_complexity)]
pub struct LkMultiplicity {
multiplicity: Arc<ThreadLocal<RefCell<[HashMap<u64, usize>; mem::variant_count::<ROMType>()]>>>,
}

#[allow(dead_code)]
impl LkMultiplicity {
/// assert within range
#[inline(always)]
pub fn assert_ux<const C: usize>(&mut self, v: u64) {
match C {
16 => self.assert_u16(v),
8 => self.assert_byte(v),
5 => self.assert_u5(v),
_ => panic!("Unsupported bit range"),
}
}

fn assert_u5(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U5 as usize]
.entry(v)
.or_default()) += 1;
}

fn assert_u16(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
.entry(v)
.or_default()) += 1;
}

fn assert_byte(&mut self, v: u64) {
let v = v * (1 << u8::BITS);
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
.entry(v)
.or_default()) += 1;
}

/// lookup a < b as unsigned byte
pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) {
let key = a.wrapping_mul(256) + b;
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::Ltu as usize]
.entry(key)
.or_default()) += 1;
}

/// merge result from multiple thread local to single result
fn into_finalize_result(self) -> [HashMap<u64, usize>; mem::variant_count::<ROMType>()] {
Arc::try_unwrap(self.multiplicity)
.unwrap()
.into_iter()
.fold(array::from_fn(|_| HashMap::new()), |mut x, y| {
x.iter_mut().zip(y.borrow().iter()).for_each(|(m1, m2)| {
for (key, value) in m2 {
*m1.entry(*key).or_insert(0) += value;
}
});
x
})
}
}

#[cfg(test)]
mod tests {
use std::thread;

use crate::{structs::ROMType, witness::LkMultiplicity};

#[test]
fn test_lk_multiplicity_threads() {
// TODO figure out a way to verify thread_local hit/miss in unittest env
let lkm = LkMultiplicity::default();
let thread_count = 20;
// each thread calling assert_byte once
for _ in 0..thread_count {
let mut lkm = lkm.clone();
thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap();
}
let res = lkm.into_finalize_result();
// check multiplicity counts of assert_byte
assert_eq!(res[ROMType::U16 as usize][&(8 << 8)], thread_count);
}
}
Loading