Skip to content

Commit

Permalink
refactor lt config to gadget and blt (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 authored Sep 26, 2024
1 parent d0d4e66 commit 5f4713d
Show file tree
Hide file tree
Showing 18 changed files with 413 additions and 237 deletions.
19 changes: 10 additions & 9 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use ff_ext::ExtensionField;
use crate::{
error::ZKVMError,
expression::{Expression, WitIn},
instructions::riscv::config::ExprLtConfig,
gadgets::IsLtConfig,
instructions::riscv::constants::UINT_LIMBS,
};

pub mod general;
Expand All @@ -19,8 +20,8 @@ pub trait GlobalStateRegisterMachineChipOperations<E: ExtensionField> {
}

/// The common representation of a register value.
/// Format: `[u16; 2]`, least-significant-first.
pub type RegisterExpr<E> = [Expression<E>; 2];
/// Format: `[u16; UINT_LIMBS]`, least-significant-first.
pub type RegisterExpr<E> = [Expression<E>; UINT_LIMBS];

pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> {
fn register_read(
Expand All @@ -30,7 +31,7 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
prev_ts: Expression<E>,
ts: Expression<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;

#[allow(clippy::too_many_arguments)]
fn register_write(
Expand All @@ -41,12 +42,12 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;
}

/// The common representation of a memory value.
/// Format: `[u16; 2]`, least-significant-first.
pub type MemoryExpr<E> = [Expression<E>; 2];
/// Format: `[u16; UINT_LIMBS]`, least-significant-first.
pub type MemoryExpr<E> = [Expression<E>; UINT_LIMBS];

pub trait MemoryChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> {
#[allow(dead_code)]
Expand All @@ -57,7 +58,7 @@ pub trait MemoryChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce()
prev_ts: Expression<E>,
ts: Expression<E>,
value: crate::chip_handler::MemoryExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;

#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
Expand All @@ -69,5 +70,5 @@ pub trait MemoryChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce()
ts: Expression<E>,
prev_values: crate::chip_handler::MemoryExpr<E>,
value: crate::chip_handler::MemoryExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;
}
68 changes: 5 additions & 63 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use std::fmt::Display;

use ff_ext::ExtensionField;
use itertools::Itertools;

use crate::{
chip_handler::utils::pows_expr,
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
gadgets::IsLtConfig,
structs::ROMType,
tables::InsnRecord,
};
Expand Down Expand Up @@ -255,11 +253,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR,
{
// TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair.
let items: Vec<Expression<E>> = vec![(ROMType::U1 as usize).into(), expr];
let rlc_record = self.rlc_chip_record(items);
self.lk_record(name_fn, rlc_record)?;
Ok(())
self.require_zero(name_fn, expr.clone() * (Expression::ONE - expr))
}

/// Assert `rom_type(a, b) = c` and that `a, b, c` are all bytes.
Expand Down Expand Up @@ -316,70 +310,18 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
}

/// less_than
pub(crate) fn less_than<N, NR>(
pub(crate) fn less_than<N, NR, const N_LIMBS: usize>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
) -> Result<ExprLtConfig, ZKVMError>
) -> Result<IsLtConfig<N_LIMBS>, ZKVMError>
where
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
{
#[cfg(feature = "riv64")]
panic!("less_than is not supported for riv64 yet");

#[cfg(feature = "riv32")]
self.namespace(
|| "less_than",
|cb| {
let name = name_fn();
let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than {
(
None,
if lt {
Expression::ONE
} else {
Expression::ZERO
},
)
} else {
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
(Some(is_lt), is_lt.expr())
};

let mut witin_u16 = |var_name: String| -> Result<WitIn, ZKVMError> {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
let witin = cb.create_witin(|| var_name.to_string())?;
cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?;
Ok(witin)
},
)
};

let diff = (0..2)
.map(|i| witin_u16(format!("diff_{i}")))
.collect::<Result<Vec<WitIn>, _>>()?;

let pows = pows_expr((1 << u16::BITS).into(), diff.len());

let diff_expr = diff
.iter()
.zip_eq(pows)
.map(|(record, beta)| beta * record.expr())
.reduce(|a, b| a + b)
.expect("reduce error");

let range = (1 << u32::BITS).into();

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Ok(ExprLtConfig { is_lt, diff })
},
)
IsLtConfig::construct_circuit(self, name_fn, lhs, rhs, assert_less_than)
}

pub(crate) fn is_equal(
Expand Down
7 changes: 4 additions & 3 deletions ceno_zkvm/src/chip_handler/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
gadgets::IsLtConfig,
instructions::riscv::constants::UINT_LIMBS,
structs::RAMType,
};
use ff_ext::ExtensionField;
Expand All @@ -19,7 +20,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
prev_ts: Expression<E>,
ts: Expression<E>,
value: MemoryExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down Expand Up @@ -66,7 +67,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
ts: Expression<E>,
prev_values: MemoryExpr<E>,
value: MemoryExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down
7 changes: 4 additions & 3 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
gadgets::IsLtConfig,
instructions::riscv::constants::UINT_LIMBS,
structs::RAMType,
};

Expand All @@ -20,7 +21,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
prev_ts: Expression<E>,
ts: Expression<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down Expand Up @@ -66,7 +67,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down
119 changes: 119 additions & 0 deletions ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use std::{fmt::Display, mem::MaybeUninit};

use ff_ext::ExtensionField;
use goldilocks::SmallField;
use itertools::Itertools;

use crate::{
chip_handler::utils::pows_expr,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
set_val,
witness::LkMultiplicity,
};

#[derive(Debug)]
pub struct IsLtConfig<const N_U16: usize> {
pub is_lt: Option<WitIn>,
pub diff: [WitIn; N_U16],
}

impl<const N_U16: usize> IsLtConfig<N_U16> {
pub fn expr<E: ExtensionField>(&self) -> Expression<E> {
self.is_lt.unwrap().expr()
}

pub fn construct_circuit<
E: ExtensionField,
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
>(
cb: &mut CircuitBuilder<E>,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
) -> Result<Self, ZKVMError> {
assert!(N_U16 >= 1);
cb.namespace(
|| "less_than",
|cb| {
let name = name_fn();
let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than {
(
None,
if lt {
Expression::ONE
} else {
Expression::ZERO
},
)
} else {
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
cb.assert_bit(|| "is_lt_bit", is_lt.expr())?;
(Some(is_lt), is_lt.expr())
};

let mut witin_u16 = |var_name: String| -> Result<WitIn, ZKVMError> {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
let witin = cb.create_witin(|| var_name.to_string())?;
cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?;
Ok(witin)
},
)
};

let diff = (0..N_U16)
.map(|i| witin_u16(format!("diff_{i}")))
.collect::<Result<Vec<WitIn>, _>>()?;

let pows = pows_expr((1 << u16::BITS).into(), diff.len());

let diff_expr = diff
.iter()
.zip_eq(pows)
.map(|(record, beta)| beta * record.expr())
.reduce(|a, b| a + b)
.expect("reduce error");

let range = (1 << (N_U16 * u16::BITS as usize)).into();

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Ok(IsLtConfig {
is_lt,
diff: diff.try_into().unwrap(),
})
},
)
}

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
lkm: &mut LkMultiplicity,
lhs: u64,
rhs: u64,
) -> Result<(), ZKVMError> {
let is_lt = if let Some(is_lt_wit) = self.is_lt {
let is_lt = lhs < rhs;
set_val!(instance, is_lt_wit, is_lt as u64);
is_lt
} else {
// assert is_lt == true
true
};

let diff = if is_lt { 1u64 << u32::BITS } else { 0 } + lhs - rhs;
self.diff.iter().enumerate().for_each(|(i, wit)| {
// extract the 16 bit limb from diff and assign to instance
let val = (diff >> (i * u16::BITS as usize)) & 0xffff;
lkm.assert_ux::<16>(val);
set_val!(instance, wit, val);
});
Ok(())
}
}
28 changes: 18 additions & 10 deletions ceno_zkvm/src/gadgets/is_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@ impl IsZeroConfig {
self.is_zero.expr()
}

pub fn construct_circuit<E: ExtensionField>(
pub fn construct_circuit<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR>(
cb: &mut CircuitBuilder<E>,
name_fn: N,
x: Expression<E>,
) -> Result<Self, ZKVMError> {
let is_zero = cb.create_witin(|| "is_zero")?;
let inverse = cb.create_witin(|| "inv")?;
cb.namespace(name_fn, |cb| {
let is_zero = cb.create_witin(|| "is_zero")?;
let inverse = cb.create_witin(|| "inv")?;

// x==0 => is_zero=1
cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?;
// x==0 => is_zero=1
cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?;

// x!=0 => is_zero=0
cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?;
// x!=0 => is_zero=0
cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?;

Ok(IsZeroConfig { is_zero, inverse })
Ok(IsZeroConfig { is_zero, inverse })
})
}

pub fn assign_instance<F: SmallField>(
Expand Down Expand Up @@ -61,12 +64,17 @@ impl IsEqualConfig {
self.0.expr()
}

pub fn construct_circuit<E: ExtensionField>(
pub fn construct_circuit<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR>(
cb: &mut CircuitBuilder<E>,
name_fn: N,
a: Expression<E>,
b: Expression<E>,
) -> Result<Self, ZKVMError> {
Ok(IsEqualConfig(IsZeroConfig::construct_circuit(cb, a - b)?))
Ok(IsEqualConfig(IsZeroConfig::construct_circuit(
cb,
name_fn,
a - b,
)?))
}

pub fn assign_instance<F: SmallField>(
Expand Down
2 changes: 2 additions & 0 deletions ceno_zkvm/src/gadgets/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
mod is_lt;
mod is_zero;
pub use is_lt::IsLtConfig;
pub use is_zero::{IsEqualConfig, IsZeroConfig};
Loading

0 comments on commit 5f4713d

Please sign in to comment.