Skip to content

Commit

Permalink
Fix/divu remainder larger than quotient (#335)
Browse files Browse the repository at this point in the history
### Desc.
~~This PR is trying to integrate `DivConfig` introduced in #304 into
`divu` opcode. However,`DivConfig` fails when `divisor` is zero.~~

### Findings
~~`assert_less_than` should be `None` in this case since both of `lhs:
remainder` and `rhs: divisor` are witnesses and range-checked.~~

https://github.com/scroll-tech/ceno/blob/77a250c9e7988e6d126154c34396912721394dfe/ceno_zkvm/src/gadgets/div.rs#L44-L50

----
### Update
~~In the first place, this PR is trying to fix a failure when `divisor`
is zero. During fixing this issue, realized we don't really need
`DivConfig` since there are not much duplicated code between `SRL` and
`DIVU`. Therefore, we don't really need to "fix" this issue since it
works well on `SRL`. But I need to add "remainder < divisor" constraints
in `DIVU`.~~
I was trying to use `DivConfig` in this PR, but it made more complicated
code. So, the current version is only to add `remainder < divisor`
constraints in `divu`.

### Changes
- Fixing under constraints in `less_than`
- Adding "remainder < divisor" constraint in `divu`
- Adding some tests for `SRL` and `SLL`

---------

Co-authored-by: Ming <[email protected]>
  • Loading branch information
KimiWu123 and hero78119 authored Oct 10, 2024
1 parent bac2681 commit 718def0
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 102 deletions.
19 changes: 0 additions & 19 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
use std::fmt::Display;

use ff_ext::ExtensionField;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
error::ZKVMError,
expression::{Expression, Fixed, Instance, ToExpr, WitIn},
gadgets::IsLtConfig,
instructions::riscv::constants::EXIT_CODE_IDX,
structs::ROMType,
tables::InsnRecord,
Expand Down Expand Up @@ -328,22 +325,6 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.logic_u8(ROMType::Pow, 2.into(), b, c)
}

/// less_than
pub(crate) fn less_than<N, NR>(
&mut self,
name_fn: N,
lhs: Expression<E>,
rhs: Expression<E>,
assert_less_than: Option<bool>,
max_num_u16_limbs: usize,
) -> Result<IsLtConfig, ZKVMError>
where
NR: Into<String> + Display + Clone,
N: FnOnce() -> NR,
{
IsLtConfig::construct_circuit(self, name_fn, lhs, rhs, assert_less_than, max_num_u16_limbs)
}

pub(crate) fn is_equal(
&mut self,
lhs: Expression<E>,
Expand Down
6 changes: 4 additions & 2 deletions ceno_zkvm/src/chip_handler/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
let lt_cfg = cb.less_than(
let lt_cfg = IsLtConfig::construct_circuit(
cb,
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Expand Down Expand Up @@ -102,7 +103,8 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> MemoryChipOpera
cb.read_record(|| "read_record", read_record)?;
cb.write_record(|| "write_record", write_record)?;

let lt_cfg = cb.less_than(
let lt_cfg = IsLtConfig::construct_circuit(
cb,
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Expand Down
6 changes: 4 additions & 2 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.write_record(|| "write_record", write_record)?;

// assert prev_ts < current_ts
let lt_cfg = cb.less_than(
let lt_cfg = IsLtConfig::construct_circuit(
cb,
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Expand Down Expand Up @@ -103,7 +104,8 @@ impl<'a, E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOpe
cb.read_record(|| "read_record", read_record)?;
cb.write_record(|| "write_record", write_record)?;

let lt_cfg = cb.less_than(
let lt_cfg = IsLtConfig::construct_circuit(
cb,
|| "prev_ts < ts",
prev_ts,
ts.clone(),
Expand Down
15 changes: 7 additions & 8 deletions ceno_zkvm/src/gadgets/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use ff_ext::ExtensionField;
use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
instructions::riscv::constants::{UInt, BIT_WIDTH},
instructions::riscv::constants::{UInt, UINT_LIMBS},
witness::LkMultiplicity,
Value,
};
Expand All @@ -32,18 +32,18 @@ impl<E: ExtensionField> DivConfig<E> {
remainder: &UInt<E>,
) -> Result<Self, ZKVMError> {
circuit_builder.namespace(name_fn, |cb| {
let intermediate_mul =
divisor.mul::<BIT_WIDTH, _, _>(|| "divisor_mul", cb, quotient, true)?;
let dividend = intermediate_mul.add(|| "dividend_add", cb, remainder, true)?;
let (dividend, intermediate_mul) =
divisor.mul_add(|| "divisor * outcome + r", cb, quotient, remainder, true)?;

// remainder range check
let r_lt = cb.less_than(
let r_lt = IsLtConfig::construct_circuit(
cb,
|| "remainder < divisor",
remainder.value(),
divisor.value(),
Some(true),
UInt::<E>::NUM_CELLS,
UINT_LIMBS,
)?;

Ok(Self {
dividend,
intermediate_mul,
Expand All @@ -61,7 +61,6 @@ impl<E: ExtensionField> DivConfig<E> {
remainder: &Value<'a, u32>,
) -> Result<(), ZKVMError> {
let (dividend, intermediate) = divisor.mul_add(quotient, remainder, lkm, true);

self.r_lt
.assign_instance(instance, lkm, remainder.as_u64(), divisor.as_u64())?;
self.intermediate_mul
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl IsLtConfig {
) -> Result<Self, ZKVMError> {
assert!(max_num_u16_limbs >= 1);
cb.namespace(
|| "less_than",
|| "is_lt",
|cb| {
let name = name_fn();
let (is_lt, is_lt_expr) = if let Some(lt) = assert_less_than {
Expand Down
94 changes: 60 additions & 34 deletions ceno_zkvm/src/instructions/riscv/divu.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction};
use super::{
constants::{UInt, UINT_LIMBS},
r_insn::RInstructionConfig,
RIVInstruction,
};
use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::IsZeroConfig,
instructions::Instruction, uint::Value, witness::LkMultiplicity,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::Expression,
gadgets::{IsLtConfig, IsZeroConfig},
instructions::Instruction,
uint::Value,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;
use std::marker::PhantomData;
Expand All @@ -19,6 +28,7 @@ pub struct ArithConfig<E: ExtensionField> {
remainder: UInt<E>,
inter_mul_value: UInt<E>,
is_zero: IsZeroConfig,
pub remainder_lt: IsLtConfig,
}

pub struct ArithInstruction<E, I>(PhantomData<(E, I)>);
Expand All @@ -36,37 +46,46 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
fn construct_circuit(cb: &mut CircuitBuilder<E>) -> Result<Self::InstructionConfig, ZKVMError> {
// outcome = dividend / divisor + remainder => dividend = divisor * outcome + r
let mut divisor = UInt::new_unchecked(|| "divisor", circuit_builder)?;
let mut outcome = UInt::new(|| "outcome", circuit_builder)?;
let r = UInt::new(|| "remainder", circuit_builder)?;

let mut divisor = UInt::new_unchecked(|| "divisor", cb)?;
let mut outcome = UInt::new(|| "outcome", cb)?;
let r = UInt::new(|| "remainder", cb)?;
let (dividend, inter_mul_value) =
divisor.mul_add(|| "dividend", circuit_builder, &mut outcome, &r, true)?;
divisor.mul_add(|| "divisor * outcome + r", cb, &mut outcome, &r, true)?;

// div by zero check
let is_zero = IsZeroConfig::construct_circuit(
circuit_builder,
|| "divisor_zero_check",
let is_zero =
IsZeroConfig::construct_circuit(cb, || "divisor_zero_check", divisor.value())?;
let outcome_value = outcome.value();
cb.condition_require_equal(
|| "outcome_is_zero",
is_zero.expr(),
outcome_value.clone(),
((1u64 << UInt::<E>::M) - 1).into(),
outcome_value,
)?;

// remainder should be less than divisor if divisor != 0.
let lt = IsLtConfig::construct_circuit(
cb,
|| "remainder < divisor?",
r.value(),
divisor.value(),
None,
UINT_LIMBS,
)?;

let outcome_value = outcome.value();
circuit_builder
.condition_require_equal(
|| "outcome_is_zero",
is_zero.expr(),
outcome_value.clone(),
((1u64 << UInt::<E>::M) - 1).into(),
outcome_value,
)
.unwrap();
// When divisor is zero, remainder is -1 implies "remainder > divisor" aka. lt.expr() == 0
// otherwise lt.expr() == 1
cb.require_equal(
|| "remainder < divisor when non-zero divisor",
is_zero.expr() + lt.expr(),
Expression::ONE,
)?;

let r_insn = RInstructionConfig::<E>::construct_circuit(
circuit_builder,
cb,
I::INST_KIND,
dividend.register_expr(),
divisor.register_expr(),
Expand All @@ -81,6 +100,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
remainder: r,
inter_mul_value,
is_zero,
remainder_lt: lt,
})
}

Expand Down Expand Up @@ -114,18 +134,18 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
.assign_limbs(instance, outcome.as_u16_limbs());

let (dividend, inter_mul_value) = divisor.mul_add(&outcome, &r, lkm, true);

config
.inter_mul_value
.assign_mul_outcome(instance, lkm, &inter_mul_value)?;

config.dividend.assign_add_outcome(instance, &dividend);

config.remainder.assign_limbs(instance, r.as_u16_limbs());

config
.is_zero
.assign_instance(instance, (rs2 as u64).into())?;
.assign_instance(instance, divisor.as_u64().into())?;
config
.remainder_lt
.assign_instance(instance, lkm, r.as_u64(), divisor.as_u64())?;

Ok(())
}
Expand All @@ -152,17 +172,22 @@ mod test {
Value,
};

fn verify(name: &'static str, dividend: Word, divisor: Word, outcome: Word) {
fn verify(name: &'static str, dividend: Word, divisor: Word, exp_outcome: Word) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);
let config = cb
.namespace(
|| format!("divu_{name}"),
|| format!("divu_({name})"),
|cb| Ok(DivUInstruction::construct_circuit(cb)),
)
.unwrap()
.unwrap();

let outcome = if divisor == 0 {
u32::MAX
} else {
dividend / divisor
};
// values assignment
let (raw_witin, _) = DivUInstruction::assign_instances(
&config,
Expand All @@ -179,8 +204,9 @@ mod test {
)
.unwrap();

let expected_rd_written =
UInt::from_const_unchecked(Value::new_unchecked(outcome).as_u16_limbs().to_vec());
let expected_rd_written = UInt::from_const_unchecked(
Value::new_unchecked(exp_outcome).as_u16_limbs().to_vec(),
);

config
.outcome
Expand All @@ -206,8 +232,8 @@ mod test {
verify("u32::MAX", u32::MAX, u32::MAX, 1);
verify("div u32::MAX", 3, u32::MAX, 0);
verify("u32::MAX div by 2", u32::MAX, 2, u32::MAX / 2);
verify("mul with carries", 1202729773, 171818539, 7);
verify("div by zero", 10, 0, u32::MAX);
verify("mul carry", 1202729773, 171818539, 7);
}

#[test]
Expand Down
Loading

0 comments on commit 718def0

Please sign in to comment.