Skip to content

Commit

Permalink
test: add M/C!=4 cases
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Sep 9, 2024
1 parent 8dc4523 commit bad8dce
Showing 1 changed file with 61 additions and 8 deletions.
69 changes: 61 additions & 8 deletions ceno_zkvm/src/uint/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,9 +691,20 @@ mod tests {
use itertools::Itertools;

type E = GoldilocksExt2; // 18446744069414584321
const POW_OF_C: u64 = 2_usize.pow(16u32) as u64;
// const POW_OF_C: u64 = 2_usize.pow(16u32) as u64;

// fn test_data() -> Vec<Vec<u64>> {
// let d = vec![
// [256, 256, 0, 0].into(),
// [257, 256, 0, 0].into(),
// [256, 257, 2, 1].into(),
// [1, 2, 1].into(),
// ];
// d
// }

#[test]
fn test_mul_no_carries() {
fn test_mul64_16_no_carries() {
// a = 1 + 1 * 2^16
// b = 2 + 1 * 2^16
// c = 2 + 3 * 2^16 + 1 * 2^32 = 4,295,163,906
Expand All @@ -706,7 +717,7 @@ mod tests {
}

#[test]
fn test_mul_w_carry() {
fn test_mul64_16_w_carry() {
// a = 256 + 1 * 2^16
// b = 257 + 1 * 2^16
// c = 256 + 514 * 2^16 + 1 * 2^32 = 4,328,653,056
Expand All @@ -719,7 +730,7 @@ mod tests {
}

#[test]
fn test_mul_w_carries() {
fn test_mul64_16_w_carries() {
// a = 256 + 256 * 2^16 = 16,777,472
// b = 257 + 256 * 2^16 = 16,777,473
// c = 256 + 257 * 2^16 + 2 * 2^32 + 1 * 2^48 = 281,483,583,488,256
Expand All @@ -734,11 +745,50 @@ mod tests {
verify::<64, 16, E>(witness_values, false);
}

#[test]
fn test_mul64_8_w_carries() {
// a = 256
// b = 257
// c = 254 + 1 * 2^16 = 510
let wit_a = vec![255, 0, 0, 0, 0, 0, 0, 0];
let wit_b = vec![2, 0, 0, 0, 0, 0, 0, 0];
let wit_c = vec![254, 1, 0, 0, 0, 0, 0, 0];
let wit_carries = vec![1, 0, 0, 0, 0, 0, 0];
let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat();
verify::<64, 8, E>(witness_values, false);
}

#[test]
fn test_mul32_16_w_carries() {
// a = 256
// b = 257
// c = 256 + 1 * 2^16 = 65,792
let wit_a = vec![256, 0];
let wit_b = vec![257, 0];
let wit_c = vec![256, 1];
let wit_carries = vec![1, 0];
let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat();
verify::<32, 16, E>(witness_values, false);
}

#[test]
fn test_mul32_5_w_carries() {
// a = 31
// b = 2
// c = 30 + 1 * 2^8 = 62
let wit_a = vec![31, 0, 0, 0, 0, 0, 0];
let wit_b = vec![2, 0, 0, 0, 0, 0, 0];
let wit_c = vec![30, 1, 0, 0, 0, 0, 0];
let wit_carries = vec![1, 0, 0, 0, 0, 0];
let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat();
verify::<32, 5, E>(witness_values, false);
}

fn verify<const M: usize, const C: usize, E: ExtensionField>(
witness_values: Vec<u64>,
overflow: bool,
) {
let mut cs = ConstraintSystem::new(|| "test");
let mut cs = ConstraintSystem::new(|| "test_mul");
let mut circuit_builder = CircuitBuilder::<E>::new(&mut cs);
let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec();

Expand All @@ -748,6 +798,7 @@ mod tests {
.mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false)
.unwrap();

let POW_OF_C: u64 = 2_usize.pow(UInt::<M, C, E>::MAX_CELL_BIT_WIDTH as u32) as u64;
let single_wit_size = UInt::<M, C, E>::NUM_CELLS;
let wit_end_idx = if overflow {
4 * single_wit_size
Expand All @@ -767,13 +818,15 @@ mod tests {
result[idx] += a_limb * b_limb;
}
});
});

// take care carries
// take care carries
result.iter_mut().enumerate().for_each(|(i, ret)| {
if !overflow && c_carries.get(i).is_some() {
result[i] -= c_carries[i] * POW_OF_C;
*ret -= c_carries[i] * POW_OF_C;
}
if i != 0 {
result[i] += c_carries[i - 1];
*ret += c_carries[i - 1];
}
});

Expand Down

0 comments on commit bad8dce

Please sign in to comment.