Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

less_than util #183

Merged
merged 17 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use ff_ext::ExtensionField;
use crate::{
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
};

pub mod general;
Expand All @@ -23,8 +24,9 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
prev_ts: Expression<E>,
ts: Expression<E>,
values: &V,
) -> Result<Expression<E>, ZKVMError>;
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;

#[allow(clippy::too_many_arguments)]
fn register_write<V: ToExpr<E, Output = Vec<Expression<E>>>>(
hero78119 marked this conversation as resolved.
Show resolved Hide resolved
&mut self,
name_fn: N,
Expand All @@ -33,5 +35,5 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
ts: Expression<E>,
prev_values: &V,
values: &V,
) -> Result<Expression<E>, ZKVMError>;
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError>;
}
69 changes: 69 additions & 0 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Display;

use ff_ext::ExtensionField;

use ff::Field;
Expand All @@ -6,6 +8,7 @@ use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
structs::ROMType,
};

Expand Down Expand Up @@ -264,6 +267,72 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Ok(())
}

/// less_than
pub(crate) fn less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
) -> Result<ExprLtConfig, ZKVMError>
where
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
{
#[cfg(feature = "riv64")]
panic!("less_than is not supported for riv64 yet");

#[cfg(feature = "riv32")]
self.namespace(
|| "less_than",
|cb| {
let name = name_fn();
let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than {
(
None,
if lt {
Expression::ONE
} else {
Expression::ZERO
},
)
} else {
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
(Some(is_lt), is_lt.expr())
};

let mut witin_u16 = |var_name: String| -> Result<WitIn, ZKVMError> {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
let witin = cb.create_witin(|| var_name.to_string())?;
cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?;
Ok(witin)
},
)
};

let diff = (0..2)
.map(|i| witin_u16(format!("diff_{i}")))
.collect::<Result<Vec<WitIn>, _>>()?;

let diff_expr = diff
.iter()
.enumerate()
.map(|(i, diff)| (i, diff.expr()))
.fold(Expression::ZERO, |sum, (i, a)| {
sum + if i > 0 { a * (1 << (16 * i)).into() } else { a }
});

let range = Expression::Constant((1 << 32).into());

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Ok(ExprLtConfig { is_lt, diff })
},
)
}

pub(crate) fn is_equal(
&mut self,
lhs: Expression<E>,
Expand Down
24 changes: 11 additions & 13 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
instructions::riscv::config::ExprLtConfig,
structs::RAMType,
};

Expand All @@ -19,7 +20,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
prev_ts: Expression<E>,
ts: Expression<E>,
values: &V,
) -> Result<Expression<E>, ZKVMError> {
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand All @@ -29,7 +30,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
))],
vec![register_id.expr()],
values.expr(),
vec![prev_ts],
vec![prev_ts.clone()],
]
.concat(),
);
Expand All @@ -49,12 +50,11 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
// TODO implement lt gadget
// let is_lt = prev_ts.lt(self, ts)?;
// self.require_one(is_lt)?;
let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;

let next_ts = ts + 1.into();

Ok(next_ts)
Ok((next_ts, lt_cfg))
})
}

Expand All @@ -66,7 +66,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
ts: Expression<E>,
prev_values: &V,
values: &V,
) -> Result<Expression<E>, ZKVMError> {
) -> Result<(Expression<E>, ExprLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand All @@ -76,7 +76,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
))],
vec![register_id.expr()],
prev_values.expr(),
vec![prev_ts],
vec![prev_ts.clone()],
]
.concat(),
);
Expand All @@ -95,13 +95,11 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.read_record(|| "read_record", read_record)?;
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
// TODO implement lt gadget
// let is_lt = prev_ts.lt(self, ts)?;
// self.require_one(is_lt)?;
let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;

let next_ts = ts + 1.into();

Ok(next_ts)
Ok((next_ts, lt_cfg))
})
}
}
3 changes: 3 additions & 0 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ enum MonomialState {
}

impl<E: ExtensionField> Expression<E> {
pub const ZERO: Expression<E> = Expression::Constant(E::BaseField::ZERO);
pub const ONE: Expression<E> = Expression::Constant(E::BaseField::ONE);

pub fn degree(&self) -> usize {
match self {
Expression::Fixed(_) => 1,
Expand Down
36 changes: 30 additions & 6 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ff_ext::ExtensionField;
use itertools::Itertools;

use super::{
config::ExprLtConfig,
constants::{OPType, OpcodeType, RegUInt, PC_STEP_SIZE},
RIVInstruction,
};
Expand All @@ -13,7 +14,7 @@ use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::Instruction,
instructions::{riscv::config::ExprLtInput, Instruction},
set_val,
uint::UIntValue,
witness::LkMultiplicity,
Expand All @@ -37,6 +38,9 @@ pub struct InstructionConfig<E: ExtensionField> {
pub prev_rs1_ts: WitIn,
pub prev_rs2_ts: WitIn,
pub prev_rd_ts: WitIn,
pub lt_rs1_cfg: ExprLtConfig,
pub lt_rs2_cfg: ExprLtConfig,
pub lt_prev_ts_cfg: ExprLtConfig,
phantom: PhantomData<E>,
}

Expand Down Expand Up @@ -99,17 +103,17 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?;
let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?;

let ts = circuit_builder.register_read(
let (ts, lt_rs1_cfg) = circuit_builder.register_read(
|| "read_rs1",
&rs1_id,
prev_rs1_ts.expr(),
cur_ts.expr(),
&addend_0,
)?;
let ts =
let (ts, lt_rs2_cfg) =
circuit_builder.register_read(|| "read_rs2", &rs2_id, prev_rs2_ts.expr(), ts, &addend_1)?;

let ts = circuit_builder.register_write(
let (ts, lt_prev_ts_cfg) = circuit_builder.register_write(
|| "write_rd",
&rd_id,
prev_rd_ts.expr(),
Expand All @@ -134,6 +138,9 @@ fn add_sub_gadget<E: ExtensionField, const IS_ADD: bool>(
prev_rs1_ts,
prev_rs2_ts,
prev_rd_ts,
lt_rs1_cfg,
lt_rs2_cfg,
lt_prev_ts_cfg,
phantom: PhantomData,
})
}
Expand All @@ -159,7 +166,7 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
) -> Result<(), ZKVMError> {
// TODO use fields from step
set_val!(instance, config.pc, 1);
set_val!(instance, config.ts, 2);
set_val!(instance, config.ts, 3);
let addend_0 = UIntValue::new_unchecked(step.rs1().unwrap().value);
let addend_1 = UIntValue::new_unchecked(step.rs2().unwrap().value);
let rd_prev = UIntValue::new_unchecked(step.rd().unwrap().value.before);
Expand Down Expand Up @@ -187,6 +194,23 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction<E> {
set_val!(instance, config.prev_rs1_ts, 2);
set_val!(instance, config.prev_rs2_ts, 2);
set_val!(instance, config.prev_rd_ts, 2);

ExprLtInput {
lhs: 2, // rs1
rhs: 3, // cur_ts
}
.assign(instance, &config.lt_rs1_cfg);
ExprLtInput {
lhs: 2, // rs2
rhs: 4, // cur_ts
}
.assign(instance, &config.lt_rs2_cfg);
ExprLtInput {
lhs: 2, // rd
rhs: 5, // cur_ts
}
.assign(instance, &config.lt_prev_ts_cfg);

Ok(())
}
}
Expand Down Expand Up @@ -362,7 +386,7 @@ mod test {
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
Some([100.into(), 100000.into()]),
);
}
}
17 changes: 11 additions & 6 deletions ceno_zkvm/src/instructions/riscv/blt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
error::ZKVMError,
expression::{ToExpr, WitIn},
instructions::{
riscv::config::{LtConfig, LtInput},
riscv::config::{UIntLtConfig, UIntLtInput},
Instruction,
},
set_val,
Expand All @@ -19,6 +19,7 @@ use crate::{
};

use super::{
config::ExprLtConfig,
constants::{OPType, OpcodeType, RegUInt, RegUInt8, PC_STEP_SIZE},
RIVInstruction,
};
Expand All @@ -38,7 +39,9 @@ pub struct InstructionConfig<E: ExtensionField> {
pub rs2_id: WitIn,
pub prev_rs1_ts: WitIn,
pub prev_rs2_ts: WitIn,
pub is_lt: LtConfig,
pub is_lt: UIntLtConfig,
pub lt_rs1_cfg: ExprLtConfig,
pub lt_rs2_cfg: ExprLtConfig,
}

pub struct BltInput {
Expand All @@ -62,7 +65,7 @@ impl BltInput {
) {
assert!(!self.lhs_limb8.is_empty() && (self.lhs_limb8.len() == self.rhs_limb8.len()));
// TODO: add boundary check for witin
let lt_input = LtInput {
let lt_input = UIntLtInput {
lhs_limbs: &self.lhs_limb8,
rhs_limbs: &self.rhs_limb8,
};
Expand Down Expand Up @@ -175,14 +178,14 @@ fn blt_gadget<E: ExtensionField>(
let lhs = RegUInt::from_u8_limbs(circuit_builder, &lhs_limb8);
let rhs = RegUInt::from_u8_limbs(circuit_builder, &rhs_limb8);

let ts = circuit_builder.register_read(
let (ts, lt_rs1_cfg) = circuit_builder.register_read(
|| "read ts for lhs",
&rs1_id,
prev_rs1_ts.expr(),
cur_ts.expr(),
&lhs,
)?;
let ts = circuit_builder.register_read(
let (ts, lt_rs2_cfg) = circuit_builder.register_read(
|| "read ts for rhs",
&rs2_id,
prev_rs2_ts.expr(),
Expand All @@ -208,6 +211,8 @@ fn blt_gadget<E: ExtensionField>(
prev_rs1_ts,
prev_rs2_ts,
is_lt,
lt_rs1_cfg,
lt_rs2_cfg,
})
}

Expand Down Expand Up @@ -270,7 +275,7 @@ mod test {
.into_iter()
.map(|v| v.into())
.collect_vec(),
None,
Some([1.into(), 1000.into()]),
)
.expect_err("lookup will fail");
Ok(())
Expand Down
Loading
Loading