Skip to content

Commit

Permalink
wip multiplication in new way
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Oct 2, 2024
1 parent ad364d2 commit 35ceb80
Show file tree
Hide file tree
Showing 23 changed files with 372 additions and 293 deletions.
8 changes: 4 additions & 4 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
prev_ts: Expression<E>,
ts: Expression<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig), ZKVMError>;

#[allow(clippy::too_many_arguments)]
fn register_write(
Expand All @@ -42,7 +42,7 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig), ZKVMError>;
}

/// The common representation of a memory value.
Expand All @@ -58,7 +58,7 @@ pub trait MemoryChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce()
prev_ts: Expression<E>,
ts: Expression<E>,
value: crate::chip_handler::MemoryExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig), ZKVMError>;

#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
Expand All @@ -70,5 +70,5 @@ pub trait MemoryChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce()
ts: Expression<E>,
prev_values: crate::chip_handler::MemoryExpr<E>,
value: crate::chip_handler::MemoryExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError>;
) -> Result<(Expression<E>, IsLtConfig), ZKVMError>;
}
9 changes: 5 additions & 4 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
16 => self.assert_u16(name_fn, expr),
8 => self.assert_byte(name_fn, expr),
5 => self.assert_u5(name_fn, expr),
_ => panic!("Unsupported bit range"),
1 => self.assert_bit(name_fn, expr),
c => panic!("Unsupported bit range {c}"),
}
}

Expand Down Expand Up @@ -315,18 +316,18 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
}

/// less_than
pub(crate) fn less_than<N, NR, const N_LIMBS: usize>(
pub(crate) fn less_than<N, NR, const MAX_U16_LIMB: usize>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
) -> Result<IsLtConfig<N_LIMBS>, ZKVMError>
) -> Result<IsLtConfig, ZKVMError>
where
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
{
IsLtConfig::construct_circuit(self, name_fn, lhs, rhs, assert_less_than)
IsLtConfig::construct_circuit(self, name_fn, lhs, rhs, assert_less_than, MAX_U16_LIMB)
}

pub(crate) fn is_equal(
Expand Down
13 changes: 9 additions & 4 deletions ceno_zkvm/src/chip_handler/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
prev_ts: Expression<E>,
ts: Expression<E>,
value: MemoryExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down Expand Up @@ -50,7 +50,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;
let lt_cfg = cb.less_than::<_, _, UINT_LIMBS>(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;

let next_ts = ts + 1.into();

Expand All @@ -67,7 +67,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
ts: Expression<E>,
prev_values: MemoryExpr<E>,
value: MemoryExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down Expand Up @@ -96,7 +96,12 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
cb.read_record(|| "read_record", read_record)?;
cb.write_record(|| "write_record", write_record)?;

let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;
let lt_cfg = cb.less_than::<_, _, UINT_LIMBS>(
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Some(true),
)?;

let next_ts = ts + 1.into();

Expand Down
18 changes: 14 additions & 4 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
prev_ts: Expression<E>,
ts: Expression<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down Expand Up @@ -51,7 +51,12 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;
let lt_cfg = cb.less_than::<_, _, UINT_LIMBS>(
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Some(true),
)?;

let next_ts = ts + 1.into();

Expand All @@ -67,7 +72,7 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, IsLtConfig<UINT_LIMBS>), ZKVMError> {
) -> Result<(Expression<E>, IsLtConfig), ZKVMError> {
self.namespace(name_fn, |cb| {
// READ (a, v, t)
let read_record = cb.rlc_chip_record(
Expand Down Expand Up @@ -96,7 +101,12 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.read_record(|| "read_record", read_record)?;
cb.write_record(|| "write_record", write_record)?;

let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?;
let lt_cfg = cb.less_than::<_, _, UINT_LIMBS>(
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Some(true),
)?;

let next_ts = ts + 1.into();

Expand Down
33 changes: 23 additions & 10 deletions ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ use crate::{
witness::LkMultiplicity,
};

#[derive(Debug)]
pub struct IsLtConfig<const N_U16: usize> {
#[derive(Debug, Clone)]
pub struct IsLtConfig {
pub is_lt: Option<WitIn>,
pub diff: [WitIn; N_U16],
pub diff: Vec<WitIn>,
pub max_num_u16_limbs: usize,
}

impl<const N_U16: usize> IsLtConfig<N_U16> {
impl IsLtConfig {
pub fn expr<E: ExtensionField>(&self) -> Expression<E> {
self.is_lt.unwrap().expr()
}
Expand All @@ -34,8 +35,9 @@ impl<const N_U16: usize> IsLtConfig<N_U16> {
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
max_num_u16_limbs: usize,
) -> Result<Self, ZKVMError> {
assert!(N_U16 >= 1);
assert!(max_num_u16_limbs >= 1);
cb.namespace(
|| "less_than",
|cb| {
Expand Down Expand Up @@ -66,7 +68,7 @@ impl<const N_U16: usize> IsLtConfig<N_U16> {
)
};

let diff = (0..N_U16)
let diff = (0..max_num_u16_limbs)
.map(|i| witin_u16(format!("diff_{i}")))
.collect::<Result<Vec<WitIn>, _>>()?;

Expand All @@ -79,18 +81,30 @@ impl<const N_U16: usize> IsLtConfig<N_U16> {
.reduce(|a, b| a + b)
.expect("reduce error");

let range = (1 << (N_U16 * u16::BITS as usize)).into();
let range = (1 << (max_num_u16_limbs * u16::BITS as usize)).into();

cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Ok(IsLtConfig {
is_lt,
diff: diff.try_into().unwrap(),
diff,
max_num_u16_limbs,
})
},
)
}

pub fn cal_diff(is_lt: bool, max_num_u16_limbs: usize, lhs: u64, rhs: u64) -> u64 {
println!("max_num_u16_limbs {max_num_u16_limbs}");
let diff = if is_lt {
1u64 << (u16::BITS as usize * max_num_u16_limbs)
} else {
0
} + lhs
- rhs;
diff
}

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
Expand All @@ -106,8 +120,7 @@ impl<const N_U16: usize> IsLtConfig<N_U16> {
// assert is_lt == true
true
};

let diff = if is_lt { 1u64 << u32::BITS } else { 0 } + lhs - rhs;
let diff = Self::cal_diff(is_lt, self.max_num_u16_limbs, lhs, rhs);
self.diff.iter().enumerate().for_each(|(i, wit)| {
// extract the 16 bit limb from diff and assign to instance
let val = (diff >> (i * u16::BITS as usize)) & 0xffff;
Expand Down
45 changes: 18 additions & 27 deletions ceno_zkvm/src/instructions/riscv/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::marker::PhantomData;

use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;
use itertools::Itertools;

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};
use crate::{
Expand Down Expand Up @@ -115,39 +114,27 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
let rs2_read = Value::new_unchecked(step.rs2().unwrap().value);
config
.rs2_read
.assign_limbs(instance, rs2_read.u16_fields());
.assign_limbs(instance, rs2_read.as_u16_limbs());

match I::INST_KIND {
InsnKind::ADD => {
// rs1_read + rs2_read = rd_written
let rs1_read = Value::new_unchecked(step.rs1().unwrap().value);
config
.rs1_read
.assign_limbs(instance, rs1_read.u16_fields());
.assign_limbs(instance, rs1_read.as_u16_limbs());
let (_, outcome_carries) = rs1_read.add(&rs2_read, lk_multiplicity, true);
config.rd_written.assign_carries(
instance,
outcome_carries
.into_iter()
.map(|carry| E::BaseField::from(carry as u64))
.collect_vec(),
);
config.rd_written.assign_carries(instance, &outcome_carries);
}

InsnKind::SUB => {
// rs1_read = rd_written + rs2_read
let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity);
config
.rd_written
.assign_limbs(instance, rd_written.u16_fields());
.assign_limbs(instance, rd_written.as_u16_limbs());
let (_, addend_0_carries) = rs2_read.add(&rd_written, lk_multiplicity, true);
config.rs1_read.assign_carries(
instance,
addend_0_carries
.into_iter()
.map(|carry| E::BaseField::from(carry as u64))
.collect_vec(),
);
config.rs1_read.assign_carries(instance, &addend_0_carries);
}

InsnKind::MUL => {
Expand All @@ -157,20 +144,24 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E

config
.rs1_read
.assign_limbs(instance, rs1_read.u16_fields());
.assign_limbs(instance, rs1_read.as_u16_limbs());

let (_, carries, max_carry) = rs1_read.mul(&rs2_read, lk_multiplicity, true);

let (_, carries) = rs1_read.mul(&rs2_read, lk_multiplicity, true);
config
.rd_written
.assign_limbs(instance, rd_written.as_u16_limbs());

config
.rd_written
.assign_limbs(instance, rd_written.u16_fields());
config.rd_written.assign_carries(
.assign_limbs(instance, rd_written.as_u16_limbs());
config.rd_written.assign_carries(instance, &carries);
config.rd_written.assign_carries_auxiliary(
instance,
carries
.into_iter()
.map(|carry| E::BaseField::from(carry as u64))
.collect_vec(),
);
lk_multiplicity,
&carries,
max_carry,
)?;
}

_ => unreachable!("Unsupported instruction kind"),
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use super::constants::{PC_STEP_SIZE, UINT_LIMBS};
use super::constants::PC_STEP_SIZE;
use crate::{
chip_handler::{
GlobalStateRegisterMachineChipOperations, RegisterChipOperations, RegisterExpr,
Expand Down Expand Up @@ -45,8 +45,8 @@ pub struct BInstructionConfig {
imm: WitIn,
prev_rs1_ts: WitIn,
prev_rs2_ts: WitIn,
lt_rs1_cfg: IsLtConfig<UINT_LIMBS>,
lt_rs2_cfg: IsLtConfig<UINT_LIMBS>,
lt_rs1_cfg: IsLtConfig,
lt_rs2_cfg: IsLtConfig,
}

impl BInstructionConfig {
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BeqCircuit<E, I> {
let rs1_read = step.rs1().unwrap().value;
config
.rs1_read
.assign_limbs(instance, Value::new_unchecked(rs1_read).u16_fields());
.assign_limbs(instance, Value::new_unchecked(rs1_read).as_u16_limbs());

let rs2_read = step.rs2().unwrap().value;
config
.rs2_read
.assign_limbs(instance, Value::new_unchecked(rs2_read).u16_fields());
.assign_limbs(instance, Value::new_unchecked(rs2_read).as_u16_limbs());

config.equal.assign_instance(
instance,
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/branch/blt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BltCircuit<I> {
) -> Result<(), ZKVMError> {
let rs1 = Value::new_unchecked(step.rs1().unwrap().value);
let rs2 = Value::new_unchecked(step.rs2().unwrap().value);
config.read_rs1.assign_limbs(instance, rs1.u16_fields());
config.read_rs2.assign_limbs(instance, rs2.u16_fields());
config.read_rs1.assign_limbs(instance, rs1.as_u16_limbs());
config.read_rs2.assign_limbs(instance, rs2.as_u16_limbs());
config.is_lt.assign_instance::<E>(
instance,
lk_multiplicity,
Expand Down
7 changes: 4 additions & 3 deletions ceno_zkvm/src/instructions/riscv/branch/bltu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct InstructionConfig<E: ExtensionField> {
pub b_insn: BInstructionConfig,
pub read_rs1: UInt<E>,
pub read_rs2: UInt<E>,
pub is_lt: IsLtConfig<UINT_LIMBS>,
pub is_lt: IsLtConfig,
}

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BltuCircuit<I> {
Expand All @@ -48,6 +48,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BltuCircuit<I> {
read_rs1.value(),
read_rs2.value(),
None,
UINT_LIMBS,
)?;

let branch_taken_bit = match I::INST_KIND {
Expand Down Expand Up @@ -81,8 +82,8 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BltuCircuit<I> {
) -> Result<(), ZKVMError> {
let rs1 = Value::new_unchecked(step.rs1().unwrap().value);
let rs2 = Value::new_unchecked(step.rs2().unwrap().value);
config.read_rs1.assign_limbs(instance, rs1.u16_fields());
config.read_rs2.assign_limbs(instance, rs2.u16_fields());
config.read_rs1.assign_limbs(instance, rs1.as_u16_limbs());
config.read_rs2.assign_limbs(instance, rs2.as_u16_limbs());
config.is_lt.assign_instance(
instance,
lk_multiplicity,
Expand Down
Loading

0 comments on commit 35ceb80

Please sign in to comment.