diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 1362ce74d..8a0e58f37 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -691,9 +691,20 @@ mod tests { use itertools::Itertools; type E = GoldilocksExt2; // 18446744069414584321 - const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; + // const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; + + // fn test_data() -> Vec> { + // let d = vec![ + // [256, 256, 0, 0].into(), + // [257, 256, 0, 0].into(), + // [256, 257, 2, 1].into(), + // [1, 2, 1].into(), + // ]; + // d + // } + #[test] - fn test_mul_no_carries() { + fn test_mul64_16_no_carries() { // a = 1 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 2 + 3 * 2^16 + 1 * 2^32 = 4,295,163,906 @@ -706,7 +717,7 @@ mod tests { } #[test] - fn test_mul_w_carry() { + fn test_mul64_16_w_carry() { // a = 256 + 1 * 2^16 // b = 257 + 1 * 2^16 // c = 256 + 514 * 2^16 + 1 * 2^32 = 4,328,653,056 @@ -719,7 +730,7 @@ mod tests { } #[test] - fn test_mul_w_carries() { + fn test_mul64_16_w_carries() { // a = 256 + 256 * 2^16 = 16,777,472 // b = 257 + 256 * 2^16 = 16,777,473 // c = 256 + 257 * 2^16 + 2 * 2^32 + 1 * 2^48 = 281,483,583,488,256 @@ -734,11 +745,50 @@ mod tests { verify::<64, 16, E>(witness_values, false); } + #[test] + fn test_mul64_8_w_carries() { + // a = 256 + // b = 257 + // c = 254 + 1 * 2^16 = 510 + let wit_a = vec![255, 0, 0, 0, 0, 0, 0, 0]; + let wit_b = vec![2, 0, 0, 0, 0, 0, 0, 0]; + let wit_c = vec![254, 1, 0, 0, 0, 0, 0, 0]; + let wit_carries = vec![1, 0, 0, 0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<64, 8, E>(witness_values, false); + } + + #[test] + fn test_mul32_16_w_carries() { + // a = 256 + // b = 257 + // c = 256 + 1 * 2^16 = 65,792 + let wit_a = vec![256, 0]; + let wit_b = vec![257, 0]; + let wit_c = vec![256, 1]; + let wit_carries = vec![1, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<32, 16, E>(witness_values, false); + } + + #[test] + fn test_mul32_5_w_carries() { + // a = 31 + // b = 2 + // c = 30 + 1 * 2^8 = 62 + let wit_a = vec![31, 0, 0, 0, 0, 0, 0]; + let wit_b = vec![2, 0, 0, 0, 0, 0, 0]; + let wit_c = vec![30, 1, 0, 0, 0, 0, 0]; + let wit_carries = vec![1, 0, 0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<32, 5, E>(witness_values, false); + } + fn verify( witness_values: Vec, overflow: bool, ) { - let mut cs = ConstraintSystem::new(|| "test"); + let mut cs = ConstraintSystem::new(|| "test_mul"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); @@ -748,6 +798,7 @@ mod tests { .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) .unwrap(); + let POW_OF_C: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; let single_wit_size = UInt::::NUM_CELLS; let wit_end_idx = if overflow { 4 * single_wit_size @@ -767,13 +818,15 @@ mod tests { result[idx] += a_limb * b_limb; } }); + }); - // take care carries + // take care carries + result.iter_mut().enumerate().for_each(|(i, ret)| { if !overflow && c_carries.get(i).is_some() { - result[i] -= c_carries[i] * POW_OF_C; + *ret -= c_carries[i] * POW_OF_C; } if i != 0 { - result[i] += c_carries[i - 1]; + *ret += c_carries[i - 1]; } });