Skip to content

Commit

Permalink
less_than util (#183)
Browse files Browse the repository at this point in the history
Closes #167
  • Loading branch information
zemse authored Sep 12, 2024
1 parent 534269e commit 67e530b
Show file tree
Hide file tree
Showing 9 changed files with 463 additions and 52 deletions.
6 changes: 4 additions & 2 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use ff_ext::ExtensionField;
use crate::{
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
};

pub mod general;
Expand All @@ -23,8 +24,9 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
prev_ts: Expression<E>,
ts: Expression<E>,
values: &V,
) -> Result<Expression<E>, ZKVMError>;
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;

#[allow(clippy::too_many_arguments)]
fn register_write<V: ToExpr<E, Output = Vec<Expression<E>>>>(
&mut self,
name_fn: N,
Expand All @@ -33,5 +35,5 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
ts: Expression<E>,
prev_values: &V,
values: &V,
) -> Result<Expression<E>, ZKVMError>;
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;
}
69 changes: 69 additions & 0 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Display;

use ff_ext::ExtensionField;

use ff::Field;
Expand All @@ -6,6 +8,7 @@ use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
structs::ROMType,
};

Expand Down Expand Up @@ -264,6 +267,72 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

/// less_than
pub(crate) fn less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
) -> Result<ExprLtConfig, ZKVMError>
where
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
{
#[cfg(feature = "riv64")]
panic!("less_than is not supported for riv64 yet");

#[cfg(feature = "riv32")]
self.namespace(
|| "less_than",
|cb| {
let name = name_fn();
let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than {
(
None,
if lt {
Expression::ONE
} else {
Expression::ZERO
},
)
} else {
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
(Some(is_lt), is_lt.expr())
};

let mut witin_u16 = |var_name: String| -> Result<WitIn, ZKVMError> {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
let witin = cb.create_witin(|| var_name.to_string())?;
cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?;
Ok(witin)
},
)
};

let diff = (0..2)
.map(|i| witin_u16(format!("diff_{i}")))
.collect::<Result<Vec<WitIn>, _>>()?;

let diff_expr = diff
.iter()
.enumerate()
.map(|(i, diff)| (i, diff.expr()))
.fold(Expression::ZERO, |sum, (i, a)| {
sum + if i > 0 { a * (1 << (16 * i)).into() } else { a }
});

let range = Expression::Constant((1 << 32).into());

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Ok(ExprLtConfig { is_lt, diff })
},
)
}

pub(crate) fn is_equal(
&mut self,
lhs: Expression<E>,
Expand Down
24 changes: 11 additions & 13 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
structs::RAMType,
};

Expand All @@ -19,7 +20,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
prev_ts: Expression<E>,
ts: Expression<E>,
values: &V,
) -> Result<Expression<E>, ZKVMError> {
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand All @@ -29,7 +30,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
))],
vec![register_id.expr()],
values.expr(),
vec![prev_ts],
vec![prev_ts.clone()],
]
.concat(),
);
Expand All @@ -49,12 +50,11 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
// TODO implement lt gadget
// let is_lt = prev_ts.lt(self, ts)?;
// self.require_one(is_lt)?;
let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;

let next_ts = ts + 1.into();

Ok(next_ts)
Ok((next_ts, lt_cfg))
})
}

Expand All @@ -66,7 +66,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
ts: Expression<E>,
prev_values: &V,
values: &V,
) -> Result<Expression<E>, ZKVMError> {
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand All @@ -76,7 +76,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
))],
vec![register_id.expr()],
prev_values.expr(),
vec![prev_ts],
vec![prev_ts.clone()],
]
.concat(),
);
Expand All @@ -95,13 +95,11 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.read_record(|| "read_record", read_record)?;
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
// TODO implement lt gadget
// let is_lt = prev_ts.lt(self, ts)?;
// self.require_one(is_lt)?;
let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;

let next_ts = ts + 1.into();

Ok(next_ts)
Ok((next_ts, lt_cfg))
})
}
}
3 changes: 3 additions & 0 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ enum MonomialState {
}

impl<E: ExtensionField> Expression<E> {
pub const ZERO: Expression<E> = Expression::Constant(E::BaseField::ZERO);
pub const ONE: Expression<E> = Expression::Constant(E::BaseField::ONE);

pub fn degree(&self) -> usize {
match self {
Expression::Fixed(_) => 1,
Expand Down
36 changes: 30 additions & 6 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ff_ext::ExtensionField;
use itertools::Itertools;

use super::{
config::ExprLtConfig,
constants::{OPType, OpcodeType, RegUInt, PC_STEP_SIZE},
RIVInstruction,
};
Expand All @@ -13,7 +14,7 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::Instruction,
instructions::{riscv::config::ExprLtInput, Instruction},
set_val,
uint::UIntValue,
witness::LkMultiplicity,
Expand All @@ -37,6 +38,9 @@ pub struct InstructionConfig<E: ExtensionField> {
pub prev_rs1_ts: WitIn,
pub prev_rs2_ts: WitIn,
pub prev_rd_ts: WitIn,
pub lt_rs1_cfg: ExprLtConfig,
pub lt_rs2_cfg: ExprLtConfig,
pub lt_prev_ts_cfg: ExprLtConfig,
phantom: PhantomData<E>,
}

Expand Down Expand Up @@ -99,17 +103,17 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?;
let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?;

let ts = circuit_builder.register_read(
let (ts, lt_rs1_cfg) = circuit_builder.register_read(
|| "read_rs1",
&rs1_id,
prev_rs1_ts.expr(),
cur_ts.expr(),
&addend_0,
)?;
let ts =
let (ts, lt_rs2_cfg) =
circuit_builder.register_read(|| "read_rs2", &rs2_id, prev_rs2_ts.expr(), ts, &addend_1)?;

let ts = circuit_builder.register_write(
let (ts, lt_prev_ts_cfg) = circuit_builder.register_write(
|| "write_rd",
&rd_id,
prev_rd_ts.expr(),
Expand All @@ -134,6 +138,9 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
prev_rs1_ts,
prev_rs2_ts,
prev_rd_ts,
lt_rs1_cfg,
lt_rs2_cfg,
lt_prev_ts_cfg,
phantom: PhantomData,
})
}
Expand All @@ -159,7 +166,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
) -> Result<(), ZKVMError> {
// TODO use fields from step
set_val!(instance, config.pc, 1);
set_val!(instance, config.ts, 2);
set_val!(instance, config.ts, 3);
let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value);
let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value);
let rd_prev = UIntValue::new_unchecked(step.rd().unwrap().value.before);
Expand Down Expand Up @@ -187,6 +194,23 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
set_val!(instance, config.prev_rs1_ts, 2);
set_val!(instance, config.prev_rs2_ts, 2);
set_val!(instance, config.prev_rd_ts, 2);

ExprLtInput {
lhs: 2, // rs1
rhs: 3, // cur_ts
}
.assign(instance, &config.lt_rs1_cfg);
ExprLtInput {
lhs: 2, // rs2
rhs: 4, // cur_ts
}
.assign(instance, &config.lt_rs2_cfg);
ExprLtInput {
lhs: 2, // rd
rhs: 5, // cur_ts
}
.assign(instance, &config.lt_prev_ts_cfg);

Ok(())
}
}
Expand Down Expand Up @@ -362,7 +386,7 @@ mod test {
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
Some([100.into(), 100000.into()]),
);
}
}
17 changes: 11 additions & 6 deletions ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::{
riscv::config::{LtConfig, LtInput},
riscv::config::{UIntLtConfig, UIntLtInput},
Instruction,
},
set_val,
Expand All @@ -19,6 +19,7 @@ use crate::{
};

use super::{
config::ExprLtConfig,
constants::{OPType, OpcodeType, RegUInt, RegUInt8, PC_STEP_SIZE},
RIVInstruction,
};
Expand All @@ -38,7 +39,9 @@ pub struct InstructionConfig<E: ExtensionField> {
pub rs2_id: WitIn,
pub prev_rs1_ts: WitIn,
pub prev_rs2_ts: WitIn,
pub is_lt: LtConfig,
pub is_lt: UIntLtConfig,
pub lt_rs1_cfg: ExprLtConfig,
pub lt_rs2_cfg: ExprLtConfig,
}

pub struct BltInput {
Expand All @@ -62,7 +65,7 @@ impl BltInput {
) {
assert!(!self.lhs_limb8.is_empty() && (self.lhs_limb8.len() == self.rhs_limb8.len()));
// TODO: add boundary check for witin
let lt_input = LtInput {
let lt_input = UIntLtInput {
lhs_limbs: &self.lhs_limb8,
rhs_limbs: &self.rhs_limb8,
};
Expand Down Expand Up @@ -175,14 +178,14 @@ fn blt_gadget<E: ExtensionField>(
let lhs = RegUInt::from_u8_limbs(circuit_builder, &lhs_limb8);
let rhs = RegUInt::from_u8_limbs(circuit_builder, &rhs_limb8);

let ts = circuit_builder.register_read(
let (ts, lt_rs1_cfg) = circuit_builder.register_read(
|| "read ts for lhs",
&rs1_id,
prev_rs1_ts.expr(),
cur_ts.expr(),
&lhs,
)?;
let ts = circuit_builder.register_read(
let (ts, lt_rs2_cfg) = circuit_builder.register_read(
|| "read ts for rhs",
&rs2_id,
prev_rs2_ts.expr(),
Expand All @@ -208,6 +211,8 @@ fn blt_gadget<E: ExtensionField>(
prev_rs1_ts,
prev_rs2_ts,
is_lt,
lt_rs1_cfg,
lt_rs2_cfg,
})
}

Expand Down Expand Up @@ -270,7 +275,7 @@ mod test {
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
Some([1.into(), 1000.into()]),
)
.expect_err("lookup will fail");
Ok(())
Expand Down
Loading

0 comments on commit 67e530b

Please sign in to comment.