Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Op tables definitions #243

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,17 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

/// lookup a < b as unsigned byte
/// Assert that `(a < b) == res as bool`, that `a, b` are unsigned bytes, and that `res` is 0 or 1.
pub(crate) fn lookup_ltu_limb8(
&mut self,
a: Expression<E>,
b: Expression<E>,
res: Expression<E>,
) -> Result<(), ZKVMError> {
let key = a * 256.into() + b;
let items: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from(ROMType::Ltu as u64)),
key,
a,
b,
res,
];
let rlc_record = self.rlc_chip_record(items);
Expand Down
101 changes: 18 additions & 83 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use crate::{
expression::Expression,
scheme::utils::eval_by_expr_with_fixed,
structs::{ROMType, WitnessId},
tables::{AndTable, OpsTable, ProgramTableCircuit, TableCircuit},
tables::{
AndTable, LtuTable, OpsTable, OrTable, ProgramTableCircuit, RangeTable, TableCircuit,
U16Table, U5Table, U8Table, XorTable,
},
};
use ark_std::test_rng;
use ceno_emul::{ByteAddr, CENO_PLATFORM};
Expand Down Expand Up @@ -233,76 +236,27 @@ pub(crate) struct MockProver<E: ExtensionField> {
}

fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) -> HashSet<Vec<u8>> {
fn load_u5_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..(1 << 5) {
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U5 as u64)),
i.into(),
]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_u8_table<E: ExtensionField>(
fn load_range_table<RANGE: RangeTable, E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..=u8::MAX as usize {
let rlc_record = cb.rlc_chip_record(vec![(ROMType::U8 as usize).into(), i.into()]);
for i in RANGE::content() {
let rlc_record =
cb.rlc_chip_record(vec![(RANGE::ROM_TYPE as usize).into(), (i as usize).into()]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_u16_table<E: ExtensionField>(
fn load_op_table<OP: OpsTable, E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..=u16::MAX as usize {
for [a, b, c] in OP::content() {
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U16 as u64)),
i.into(),
]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_lt_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for lhs in 0..(1 << 8) {
for rhs in 0..(1 << 8) {
let is_lt = if lhs < rhs { 1 } else { 0 };
let lhs_rhs = lhs * 256 + rhs;
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::Ltu as u64)),
lhs_rhs.into(),
is_lt.into(),
]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}
}

fn load_and_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for [a, b, c] in AndTable::content() {
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::And as u64)),
(OP::ROM_TYPE as usize).into(),
(a as usize).into(),
(b as usize).into(),
(c as usize).into(),
Expand All @@ -312,25 +266,6 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
}
}

fn load_ltu_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..=u16::MAX as usize {
let a = i >> 8;
let b = i & 0xFF;
let c = (a < b) as usize;
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::Ltu as u64)),
i.into(),
c.into(),
]);
let rlc_record = eval_by_expr(&[], &challenge, &rlc_record);
t_vec.push(rlc_record.to_repr().as_ref().to_vec());
}
}

fn load_program_table<E: ExtensionField>(
t_vec: &mut Vec<Vec<u8>>,
_cb: &CircuitBuilder<E>,
Expand All @@ -355,13 +290,13 @@ fn load_tables<E: ExtensionField>(cb: &CircuitBuilder<E>, challenge: [E; 2]) ->
}

let mut table_vec = vec![];
// TODO load more tables here
load_u5_table(&mut table_vec, cb, challenge);
load_u8_table(&mut table_vec, cb, challenge);
load_u16_table(&mut table_vec, cb, challenge);
load_lt_table(&mut table_vec, cb, challenge);
load_and_table(&mut table_vec, cb, challenge);
load_ltu_table(&mut table_vec, cb, challenge);
load_range_table::<U5Table, _>(&mut table_vec, cb, challenge);
load_range_table::<U8Table, _>(&mut table_vec, cb, challenge);
load_range_table::<U16Table, _>(&mut table_vec, cb, challenge);
load_op_table::<AndTable, _>(&mut table_vec, cb, challenge);
load_op_table::<OrTable, _>(&mut table_vec, cb, challenge);
load_op_table::<XorTable, _>(&mut table_vec, cb, challenge);
load_op_table::<LtuTable, _>(&mut table_vec, cb, challenge);
load_program_table(&mut table_vec, cb, challenge);
HashSet::from_iter(table_vec)
}
Expand Down
6 changes: 4 additions & 2 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ pub enum ROMType {
U5 = 0, // 2^5 = 32
U8, // 2^8 = 256
U16, // 2^16 = 65,536
And, // a ^ b where a, b are bytes
Ltu, // a <(usign) b where a, b are bytes
And, // a & b where a, b are bytes
Or, // a | b where a, b are bytes
Xor, // a ^ b where a, b are bytes
Ltu, // a <(usign) b where a, b are bytes and the result is 0/1.
Instruction, // Decoded instruction from the fixed program.
}

Expand Down
54 changes: 54 additions & 0 deletions ceno_zkvm/src/tables/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,57 @@ impl OpsTable for AndTable {
}
}
pub type AndTableCircuit<E> = OpsTableCircuit<E, AndTable>;

pub struct OrTable;
impl OpsTable for OrTable {
const ROM_TYPE: ROMType = ROMType::Or;
fn len() -> usize {
1 << 16
}

fn content() -> Vec<[u64; 3]> {
(0..Self::len() as u64)
.map(|i| {
let (a, b) = Self::unpack(i);
[a, b, a | b]
})
.collect()
}
}
pub type OrTableCircuit<E> = OpsTableCircuit<E, OrTable>;

pub struct XorTable;
impl OpsTable for XorTable {
const ROM_TYPE: ROMType = ROMType::Xor;
fn len() -> usize {
1 << 16
}

fn content() -> Vec<[u64; 3]> {
(0..Self::len() as u64)
.map(|i| {
let (a, b) = Self::unpack(i);
[a, b, a ^ b]
})
.collect()
}
}
pub type XorTableCircuit<E> = OpsTableCircuit<E, XorTable>;

pub struct LtuTable;
impl OpsTable for LtuTable {
const ROM_TYPE: ROMType = ROMType::Ltu;
fn len() -> usize {
1 << 16
}

fn content() -> Vec<[u64; 3]> {
(0..Self::len() as u64)
.map(|i| {
let (a, b) = Self::unpack(i);
[a, b, if a < b { 1 } else { 0 }]
})
.collect()
}
}
pub type LtuTableCircuit<E> = OpsTableCircuit<E, LtuTable>;
2 changes: 1 addition & 1 deletion ceno_zkvm/src/tables/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
mod range_impl;

mod range_circuit;
use range_circuit::{RangeTable, RangeTableCircuit};
pub use range_circuit::{RangeTable, RangeTableCircuit};

use crate::structs::ROMType;

Expand Down
55 changes: 10 additions & 45 deletions ceno_zkvm/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use thread_local::ThreadLocal;

use crate::{
structs::ROMType,
tables::{AndTable, OpsTable},
tables::{AndTable, LtuTable, OpsTable},
};

#[macro_export]
Expand Down Expand Up @@ -103,63 +103,26 @@ impl LkMultiplicity {
#[inline(always)]
pub fn assert_ux<const C: usize>(&mut self, v: u64) {
match C {
16 => self.assert_u16(v),
8 => self.assert_byte(v),
5 => self.assert_u5(v),
16 => self.increment(ROMType::U16, v),
8 => self.increment(ROMType::U8, v),
5 => self.increment(ROMType::U5, v),
_ => panic!("Unsupported bit range"),
}
}

fn assert_u5(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U5 as usize]
.entry(v)
.or_default()) += 1;
}

fn assert_u16(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U16 as usize]
.entry(v)
.or_default()) += 1;
}

fn assert_byte(&mut self, v: u64) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::U8 as usize]
.entry(v)
.or_default()) += 1;
}

/// lookup a AND b
pub fn lookup_and_byte(&mut self, a: u64, b: u64) {
self.increment(ROMType::And, AndTable::pack(a, b));
}

/// lookup a < b as unsigned byte
pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) {
let key = a.wrapping_mul(256) + b;
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::Ltu as usize]
.entry(key)
.or_default()) += 1;
self.increment(ROMType::Ltu, LtuTable::pack(a, b));
}

/// Fetch instruction at pc
pub fn fetch(&mut self, pc: u32) {
let multiplicity = self
.multiplicity
.get_or(|| RefCell::new(array::from_fn(|_| HashMap::new())));
(*multiplicity.borrow_mut()[ROMType::Instruction as usize]
.entry(pc as u64)
.or_default()) += 1;
self.increment(ROMType::Instruction, pc as u64);
}

/// merge result from multiple thread local to single result
Expand Down Expand Up @@ -201,7 +164,9 @@ mod tests {
// each thread calling assert_byte once
for _ in 0..thread_count {
let mut lkm = lkm.clone();
thread::spawn(move || lkm.assert_byte(8u64)).join().unwrap();
thread::spawn(move || lkm.assert_ux::<8>(8u64))
.join()
.unwrap();
}
let res = lkm.into_finalize_result();
// check multiplicity counts of assert_byte
Expand Down
Loading