Skip to content

Commit

Permalink
Recognise create_witin can't fail (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasgoergens authored Oct 29, 2024
1 parent 045338d commit a74cbf3
Show file tree
Hide file tree
Showing 29 changed files with 90 additions and 92 deletions.
10 changes: 5 additions & 5 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Self { cs }
}

pub fn create_witin<NR, N>(&mut self, name_fn: N) -> Result<WitIn, ZKVMError>
pub fn create_witin<NR, N>(&mut self, name_fn: N) -> WitIn
where
NR: Into<String>,
N: FnOnce() -> NR,
Expand Down Expand Up @@ -148,7 +148,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR + Clone,
{
let byte = self.cs.create_witin(name_fn.clone())?;
let byte = self.cs.create_witin(name_fn.clone());
self.assert_ux::<_, _, 8>(name_fn, byte.expr())?;

Ok(byte)
Expand All @@ -159,7 +159,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
NR: Into<String>,
N: FnOnce() -> NR + Clone,
{
let limb = self.cs.create_witin(name_fn.clone())?;
let limb = self.cs.create_witin(name_fn.clone());
self.assert_ux::<_, _, 16>(name_fn, limb.expr())?;

Ok(limb)
Expand Down Expand Up @@ -393,8 +393,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
lhs: Expression<E>,
rhs: Expression<E>,
) -> Result<(WitIn, WitIn), ZKVMError> {
let is_eq = self.create_witin(|| "is_eq")?;
let diff_inverse = self.create_witin(|| "diff_inverse")?;
let is_eq = self.create_witin(|| "is_eq");
let diff_inverse = self.create_witin(|| "diff_inverse");

self.require_zero(
|| "is equal",
Expand Down
9 changes: 3 additions & 6 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,22 +194,19 @@ impl<E: ExtensionField> ConstraintSystem<E> {
}
}

pub fn create_witin<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
n: N,
) -> Result<WitIn, ZKVMError> {
pub fn create_witin<NR: Into<String>, N: FnOnce() -> NR>(&mut self, n: N) -> WitIn {
let wit_in = WitIn {
id: {
let id = self.num_witin;
self.num_witin += 1;
self.num_witin = self.num_witin.strict_add(1);
id
},
};

let path = self.ns.compute_path(n().into());
self.witin_namespace_map.push(path);

Ok(wit_in)
wit_in
}

pub fn create_fixed<NR: Into<String>, N: FnOnce() -> NR>(
Expand Down
14 changes: 7 additions & 7 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ impl WitIn {
|| "from_expr",
|cb| {
let name = name().into();
let wit = cb.create_witin(|| name.clone())?;
let wit = cb.create_witin(|| name.clone());
if !debug {
cb.require_zero(|| name.clone(), wit.expr() - input)?;
}
Expand Down Expand Up @@ -876,7 +876,7 @@ mod tests {
type E = GoldilocksExt2;
let mut cs = ConstraintSystem::new(|| "test_root");
let mut cb = CircuitBuilder::<E>::new(&mut cs);
let x = cb.create_witin(|| "x").unwrap();
let x = cb.create_witin(|| "x");

// scaledsum * challenge
// 3 * x + 2
Expand Down Expand Up @@ -942,9 +942,9 @@ mod tests {
type E = GoldilocksExt2;
let mut cs = ConstraintSystem::new(|| "test_root");
let mut cb = CircuitBuilder::<E>::new(&mut cs);
let x = cb.create_witin(|| "x").unwrap();
let y = cb.create_witin(|| "y").unwrap();
let z = cb.create_witin(|| "z").unwrap();
let x = cb.create_witin(|| "x");
let y = cb.create_witin(|| "y");
let z = cb.create_witin(|| "z");
// scaledsum * challenge
// 3 * x + 2
let expr: Expression<E> =
Expand Down Expand Up @@ -984,8 +984,8 @@ mod tests {
type E = GoldilocksExt2;
let mut cs = ConstraintSystem::new(|| "test_root");
let mut cb = CircuitBuilder::<E>::new(&mut cs);
let x = cb.create_witin(|| "x").unwrap();
let y = cb.create_witin(|| "y").unwrap();
let x = cb.create_witin(|| "x");
let y = cb.create_witin(|| "y");
// scaledsum * challenge
// (x + 1) * (y + 1)
let expr: Expression<E> = (Into::<Expression<E>>::into(1usize) + x.expr())
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl IsLtConfig {
|| "is_lt",
|cb| {
let name = name_fn();
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"))?;
let is_lt = cb.create_witin(|| format!("{name} is_lt witin"));
cb.assert_bit(|| "is_lt_bit", is_lt.expr())?;

let config = InnerLtConfig::construct_circuit(
Expand Down Expand Up @@ -153,7 +153,7 @@ impl InnerLtConfig {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
let witin = cb.create_witin(|| var_name.to_string())?;
let witin = cb.create_witin(|| var_name.to_string());
cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?;
Ok(witin)
},
Expand Down Expand Up @@ -293,7 +293,7 @@ impl SignedLtConfig {
|| "is_signed_lt",
|cb| {
let name = name_fn();
let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"))?;
let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"));
cb.assert_bit(|| "is_lt_bit", is_lt.expr())?;
let config =
InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?;
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/gadgets/is_zero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ impl IsZeroConfig {
x: Expression<E>,
) -> Result<Self, ZKVMError> {
cb.namespace(name_fn, |cb| {
let is_zero = cb.create_witin(|| "is_zero")?;
let inverse = cb.create_witin(|| "inv")?;
let is_zero = cb.create_witin(|| "is_zero");
let inverse = cb.create_witin(|| "inv");

// x==0 => is_zero=1
cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?;
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/gadgets/signed_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl SignedExtendConfig {
) -> Result<Self, ZKVMError> {
assert!(n_bits == 8 || n_bits == 16);

let msb = cb.create_witin(|| "msb")?;
let msb = cb.create_witin(|| "msb");
// require msb is boolean
cb.assert_bit(|| "msb is boolean", msb.expr())?;

Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<E: ExtensionField> BInstructionConfig<E> {
let rs2 = ReadRS2::construct_circuit(circuit_builder, rs2_read, vm_state.ts)?;

// Immediate
let imm = circuit_builder.create_witin(|| "imm")?;
let imm = circuit_builder.create_witin(|| "imm");

// Fetch instruction
circuit_builder.lk_fetch(&InsnRecord::new(
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/ecall/halt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl<E: ExtensionField> Instruction<E> for HaltInstruction<E> {
}

fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<Self::InstructionConfig, ZKVMError> {
let prev_x10_ts = cb.create_witin(|| "prev_x10_ts")?;
let prev_x10_ts = cb.create_witin(|| "prev_x10_ts");
let exit_code = {
let exit_code = cb.query_exit_code()?;
[exit_code[0].expr(), exit_code[1].expr()]
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/ecall_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ impl EcallInstructionConfig {
syscall_ret_value: Option<RegisterExpr<E>>,
next_pc: Option<Expression<E>>,
) -> Result<Self, ZKVMError> {
let pc = cb.create_witin(|| "pc")?;
let ts = cb.create_witin(|| "cur_ts")?;
let pc = cb.create_witin(|| "pc");
let ts = cb.create_witin(|| "cur_ts");

cb.state_in(pc.expr(), ts.expr())?;
cb.state_out(
Expand All @@ -47,7 +47,7 @@ impl EcallInstructionConfig {
0.into(), // imm = 0
))?;

let prev_x5_ts = cb.create_witin(|| "prev_x5_ts")?;
let prev_x5_ts = cb.create_witin(|| "prev_x5_ts");

// read syscall_id from x5 and write return value to x5
let (_, lt_x5_cfg) = cb.register_write(
Expand Down
24 changes: 12 additions & 12 deletions ceno_zkvm/src/instructions/riscv/insn_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ impl<E: ExtensionField> StateInOut<E> {
circuit_builder: &mut CircuitBuilder<E>,
branching: bool,
) -> Result<Self, ZKVMError> {
let pc = circuit_builder.create_witin(|| "pc")?;
let pc = circuit_builder.create_witin(|| "pc");
let (next_pc_opt, next_pc_expr) = if branching {
let next_pc = circuit_builder.create_witin(|| "next_pc")?;
let next_pc = circuit_builder.create_witin(|| "next_pc");
(Some(next_pc), next_pc.expr())
} else {
(None, pc.expr() + PC_STEP_SIZE)
};
let ts = circuit_builder.create_witin(|| "ts")?;
let ts = circuit_builder.create_witin(|| "ts");
let next_ts = ts.expr() + Tracer::SUBCYCLES_PER_INSN;
circuit_builder.state_in(pc.expr(), ts.expr())?;
circuit_builder.state_out(next_pc_expr, next_ts)?;
Expand Down Expand Up @@ -87,8 +87,8 @@ impl<E: ExtensionField> ReadRS1<E> {
rs1_read: RegisterExpr<E>,
cur_ts: WitIn,
) -> Result<Self, ZKVMError> {
let id = circuit_builder.create_witin(|| "rs1_id")?;
let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?;
let id = circuit_builder.create_witin(|| "rs1_id");
let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts");
let (_, lt_cfg) = circuit_builder.register_read(
|| "read_rs1",
id,
Expand Down Expand Up @@ -142,8 +142,8 @@ impl<E: ExtensionField> ReadRS2<E> {
rs2_read: RegisterExpr<E>,
cur_ts: WitIn,
) -> Result<Self, ZKVMError> {
let id = circuit_builder.create_witin(|| "rs2_id")?;
let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?;
let id = circuit_builder.create_witin(|| "rs2_id");
let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts");
let (_, lt_cfg) = circuit_builder.register_read(
|| "read_rs2",
id,
Expand Down Expand Up @@ -197,8 +197,8 @@ impl<E: ExtensionField> WriteRD<E> {
rd_written: RegisterExpr<E>,
cur_ts: WitIn,
) -> Result<Self, ZKVMError> {
let id = circuit_builder.create_witin(|| "rd_id")?;
let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts")?;
let id = circuit_builder.create_witin(|| "rd_id");
let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts");
let prev_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?;
let (_, lt_cfg) = circuit_builder.register_write(
|| "write_rd",
Expand Down Expand Up @@ -258,7 +258,7 @@ impl<E: ExtensionField> ReadMEM<E> {
mem_read: Expression<E>,
cur_ts: WitIn,
) -> Result<Self, ZKVMError> {
let prev_ts = circuit_builder.create_witin(|| "prev_ts")?;
let prev_ts = circuit_builder.create_witin(|| "prev_ts");
let (_, lt_cfg) = circuit_builder.memory_read(
|| "read_memory",
&mem_addr,
Expand Down Expand Up @@ -313,7 +313,7 @@ impl WriteMEM {
new_value: MemoryExpr<E>,
cur_ts: WitIn,
) -> Result<Self, ZKVMError> {
let prev_ts = circuit_builder.create_witin(|| "prev_ts")?;
let prev_ts = circuit_builder.create_witin(|| "prev_ts");

let (_, lt_cfg) = circuit_builder.memory_write(
|| "write_memory",
Expand Down Expand Up @@ -408,7 +408,7 @@ impl<E: ExtensionField> MemAddr<E> {
// Witness and constrain the non-zero low bits.
let low_bits = (n_zeros..Self::N_LOW_BITS)
.map(|i| {
let bit = cb.create_witin(|| format!("addr_bit_{}", i))?;
let bit = cb.create_witin(|| format!("addr_bit_{}", i));
cb.assert_bit(|| format!("addr_bit_{}", i), bit.expr())?;
Ok(bit)
})
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/jump/auipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<E: ExtensionField> Instruction<E> for AuipcInstruction<E> {
fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<AuipcConfig<E>, ZKVMError> {
let imm = circuit_builder.create_witin(|| "imm")?;
let imm = circuit_builder.create_witin(|| "imm");
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let u_insn = UInstructionConfig::construct_circuit(
Expand All @@ -46,7 +46,7 @@ impl<E: ExtensionField> Instruction<E> for AuipcInstruction<E> {
rd_written.register_expr(),
)?;

let overflow_bit = circuit_builder.create_witin(|| "overflow_bit")?;
let overflow_bit = circuit_builder.create_witin(|| "overflow_bit");
circuit_builder.assert_bit(|| "is_bit", overflow_bit.expr())?;

// assert: imm + pc = rd_written + overflow_bit * 2^32
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/jump/jalr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<JalrConfig<E>, ZKVMError> {
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value
let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value
let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let i_insn = IInstructionConfig::construct_circuit(
Expand All @@ -63,7 +63,7 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
// 3. next_pc = next_pc_addr aligned to even value (round down)

let next_pc_addr = MemAddr::<E>::construct_unaligned(circuit_builder)?;
let overflow = circuit_builder.create_witin(|| "overflow")?;
let overflow = circuit_builder.create_witin(|| "overflow");

circuit_builder.require_equal(
|| "rs1+imm = next_pc_unrounded + overflow*2^32",
Expand Down
8 changes: 4 additions & 4 deletions ceno_zkvm/src/instructions/riscv/memory/gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {
-> Result<Vec<WitIn>, ZKVMError> {
(0..num_bytes)
.map(|i| {
let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i))?;
let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i));
cb.assert_ux::<_, _, 8>(|| "byte range check", byte.expr())?;

Ok(byte)
Expand Down Expand Up @@ -84,7 +84,7 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {
)?;

// alloc a new witIn to cache degree 2 expression
let expected_limb_change = cb.create_witin(|| "expected_limb_change")?;
let expected_limb_change = cb.create_witin(|| "expected_limb_change");
cb.condition_require_equal(
|| "expected_limb_change = select(low_bits[0], rs2 - prev)",
low_bits[0].clone(),
Expand All @@ -94,7 +94,7 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {
)?;

// alloc a new witIn to cache degree 2 expression
let expected_change = cb.create_witin(|| "expected_change")?;
let expected_change = cb.create_witin(|| "expected_change");
cb.condition_require_equal(
|| "expected_change = select(low_bits[1], limb_change*2^16, limb_change)",
low_bits[1].clone(),
Expand All @@ -117,7 +117,7 @@ impl<const N_ZEROS: usize> MemWordChange<N_ZEROS> {
let prev_limbs = prev_word.expr();
let rs2_limbs = rs2_word.expr();

let expected_change = cb.create_witin(|| "expected_change")?;
let expected_change = cb.create_witin(|| "expected_change");

// alloc a new witIn to cache degree 2 expression
cb.condition_require_equal(
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/memory/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for LoadInstruction<E,
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value
let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value
let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value
let memory_read = UInt::new(|| "memory_read", circuit_builder)?;

let memory_addr = match I::INST_KIND {
Expand All @@ -104,7 +104,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for LoadInstruction<E,
// get target limb from memory word for load instructions except LW
let target_limb = match I::INST_KIND {
InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU => {
let target_limb = circuit_builder.create_witin(|| "target_limb")?;
let target_limb = circuit_builder.create_witin(|| "target_limb");
circuit_builder.condition_require_equal(
|| "target_limb = memory_value[low_bits[1]]",
addr_low_bits[1].clone(),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/memory/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl<E: ExtensionField, I: RIVInstruction, const N_ZEROS: usize> Instruction<E>
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;
let prev_memory_value = UInt::new(|| "prev_memory_value", circuit_builder)?;
let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value
let imm = circuit_builder.create_witin(|| "imm"); // signed 12-bit value

let memory_addr = match I::INST_KIND {
InsnKind::SW => MemAddr::construct_align4(circuit_builder),
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstru
circuit_builder: &mut crate::circuit_builder::CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, crate::error::ZKVMError> {
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;
let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5")?;
let rs2_low5 = circuit_builder.create_witin(|| "rs2_low5");
// pow2_rs2_low5 is unchecked because it's assignment will be constrained due it's use in lookup_pow2 below
let mut pow2_rs2_low5 = UInt::new_unchecked(|| "pow2_rs2_low5", circuit_builder)?;
// rs2 = rs2_high | rs2_low5
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/shift_imm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftImmInstructio
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
// Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification.
let imm = circuit_builder.create_witin(|| "imm")?;
let imm = circuit_builder.create_witin(|| "imm");
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rd_written = UInt::new(|| "rd_written", circuit_builder)?;

let outflow = circuit_builder.create_witin(|| "outflow")?;
let outflow = circuit_builder.create_witin(|| "outflow");
let assert_lt_config = AssertLTConfig::construct_circuit(
circuit_builder,
|| "outflow < imm",
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/slti.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<E: ExtensionField> Instruction<E> for SltiInstruction<E> {
fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<Self::InstructionConfig, ZKVMError> {
// If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0
let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?;
let imm = cb.create_witin(|| "imm")?;
let imm = cb.create_witin(|| "imm");

let max_signed_limb_expr: Expression<_> = ((1 << (UInt::<E>::LIMB_BITS - 1)) - 1).into();
let is_rs1_neg = IsLtConfig::construct_circuit(
Expand Down
Loading

0 comments on commit a74cbf3

Please sign in to comment.