Skip to content

Commit

Permalink
avoid using lookup_ltu_limb8 instead use ux lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
zemse committed Sep 6, 2024
1 parent c05022c commit 1ab8dea
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 27 deletions.
28 changes: 17 additions & 11 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,34 +266,40 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

pub(crate) fn less_than<N, NR>(
pub(crate) fn less_than<N, NR, const C: usize>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<WitIn, ZKVMError>
) -> Result<(WitIn, WitIn), ZKVMError>
where
NR: Into<String> + Display,
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
{
self.namespace(
|| "less_than",
|cb| {
let name = name_fn();
let is_lt = cb.create_witin(|| format!("{name} witin"))?;
// TODO add name_fn to lookup_ltu_limb8, not done yet to avoid merge conflicts
cb.lookup_ltu_limb8(is_lt.expr(), lhs, rhs)?;
Ok(is_lt)
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
let diff = cb.create_witin(|| format!("{name} diff witin"))?;
let range = Expression::Constant(2u64.pow(C as u32).into());
cb.require_equal(
|| name.clone(),
lhs - rhs,
diff.expr() - is_lt.expr() * range,
)?;
cb.assert_ux::<_, _, C>(|| name, diff.expr())?;
Ok((is_lt, diff))
},
)
}

pub(crate) fn assert_less_than<N, NR>(
pub(crate) fn assert_less_than<N, NR, const C: usize>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<WitIn, ZKVMError>
) -> Result<(WitIn, WitIn), ZKVMError>
where
NR: Into<String> + Clone + Display,
N: FnOnce() -> NR,
Expand All @@ -302,9 +308,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
|| "assert_less_than",
|cb| {
let name = name_fn();
let is_lt = cb.less_than(|| name.clone(), lhs, rhs)?;
let (is_lt, diff) = cb.less_than::<_, _, C>(|| name.clone(), lhs, rhs)?;
cb.require_one(|| name, is_lt.expr())?;
Ok(is_lt)
Ok((is_lt, diff))
},
)
}
Expand Down
29 changes: 19 additions & 10 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ pub struct InstructionConfig<E: ExtensionField> {
pub prev_rs1_ts: WitIn,
pub prev_rs2_ts: WitIn,
pub prev_rd_ts: WitIn,
pub lt_wtns_rs1: WitIn,
pub lt_wtns_rs2: WitIn,
pub lt_wtns_rd: WitIn,
pub lt_wtns_rs1: (WitIn, WitIn),
pub lt_wtns_rs2: (WitIn, WitIn),
pub lt_wtns_rd: (WitIn, WitIn),
phantom: PhantomData<E>,
}

Expand Down Expand Up @@ -120,18 +120,21 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
let next_ts = ts + 1.into();
circuit_builder.state_out(next_pc, next_ts)?;

let lt_wtns_rs1 = circuit_builder.assert_less_than(
let lt_wtns_rs1 = circuit_builder.assert_less_than::<_, _, 16>(
|| "prev_rs1_ts < ts",
prev_rs1_ts.expr(),
cur_ts.expr(),
)?;
let lt_wtns_rs2 = circuit_builder.assert_less_than(
let lt_wtns_rs2 = circuit_builder.assert_less_than::<_, _, 16>(
|| "prev_rs2_ts < ts",
prev_rs2_ts.expr(),
cur_ts.expr(),
)?;
let lt_wtns_rd =
circuit_builder.assert_less_than(|| "prev_rd_ts < ts", prev_rd_ts.expr(), cur_ts.expr())?;
let lt_wtns_rd = circuit_builder.assert_less_than::<_, _, 16>(
|| "prev_rd_ts < ts",
prev_rd_ts.expr(),
cur_ts.expr(),
)?;

Ok(InstructionConfig {
pc,
Expand Down Expand Up @@ -195,9 +198,15 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {
set_val!(instance, config.prev_rs1_ts, 2);
set_val!(instance, config.prev_rs2_ts, 2);
set_val!(instance, config.prev_rd_ts, 2);
set_val!(instance, config.lt_wtns_rs1, 1);
set_val!(instance, config.lt_wtns_rs2, 1);
set_val!(instance, config.lt_wtns_rd, 1);

set_val!(instance, config.lt_wtns_rs1.0, 1);
set_val!(instance, config.lt_wtns_rs1.1, 2u64.pow(16) - 2 + 1); // range - lhs + rhs

set_val!(instance, config.lt_wtns_rs2.0, 1);
set_val!(instance, config.lt_wtns_rs2.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs

set_val!(instance, config.lt_wtns_rd.0, 1);
set_val!(instance, config.lt_wtns_rd.1, 2u64.pow(16) - 3 + 2); // range - lhs + rhs
Ok(())
}
}
Expand Down
12 changes: 6 additions & 6 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ pub fn load_u5_table<E: ExtensionField>(
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
for i in 0..32 {
for i in 0..(1 << 5) {
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U5 as u64)),
i.into(),
Expand All @@ -340,8 +340,8 @@ pub fn load_u16_table<E: ExtensionField>(
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
t_vec.reserve(u16::MAX as usize);
for i in 0..(u16::MAX as usize) {
t_vec.reserve(1 << 16);
for i in 0..(1 << 16) {
let rlc_record = cb.rlc_chip_record(vec![
Expression::Constant(E::BaseField::from(ROMType::U16 as u64)),
i.into(),
Expand All @@ -356,9 +356,9 @@ pub fn load_lt_table<E: ExtensionField>(
cb: &CircuitBuilder<E>,
challenge: [E; 2],
) {
t_vec.reserve(u16::MAX as usize);
for lhs in 0..(u8::MAX as usize) {
for rhs in 0..(u8::MAX as usize) {
t_vec.reserve(1 << 16);
for lhs in 0..(1 << 8) {
for rhs in 0..(1 << 8) {
let is_lt = if lhs < rhs { 1 } else { 0 };
let lhs_rhs = lhs * 256 + rhs;
let rlc_record = cb.rlc_chip_record(vec![
Expand Down

0 comments on commit 1ab8dea

Please sign in to comment.