From 340ef6c254c6e081ea69984953a9b671feb5ff3b Mon Sep 17 00:00:00 2001 From: KimiWu Date: Mon, 9 Sep 2024 14:23:14 +0800 Subject: [PATCH 1/7] update mul for arbitrary limb size --- ceno_zkvm/src/uint/arithmetic.rs | 77 ++++++++++++++++---------------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index caeb59273..a6bfd91a8 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -133,44 +133,45 @@ impl UInt { // result check let c_expr = c.expr(); - let c_carries = c.carries.as_ref().unwrap(); - - // TODO #174 - // a_expr[0] * b_expr[0] - c_carry[0] * 2^C = c_expr[0] - circuit_builder.require_equal( - || "c_expr[0]", - a_expr[0].clone() * b_expr[0].clone() - c_carries[0].expr() * Self::POW_OF_C.into(), - c_expr[0].clone(), - )?; - // a_expr[0] * b_expr[1] + a_expr[1] * b_expr[0] - c_carry[1] * 2^C + c_carry[0] = c_expr[1] - circuit_builder.require_equal( - || "c_expr[1]", - a_expr[0].clone() * b_expr[0].clone() - c_carries[1].expr() * Self::POW_OF_C.into() - + c_carries[0].expr(), - c_expr[1].clone(), - )?; - // a_expr[0] * b_expr[2] + a_expr[1] * b_expr[1] + a_expr[2] * b_expr[0] - - // c_carry[2] * 2^C + c_carry[1] = c_expr[2] - circuit_builder.require_equal( - || "c_expr[2]", - a_expr[0].clone() * b_expr[2].clone() - + a_expr[1].clone() * b_expr[1].clone() - + a_expr[2].clone() * b_expr[0].clone() - - c_carries[2].expr() * Self::POW_OF_C.into() - + c_carries[1].expr(), - c_expr[2].clone(), - )?; - // a_expr[0] * b_expr[3] + a_expr[1] * b_expr[2] + a_expr[2] * b_expr[1] + - // a_expr[3] * b_expr[0] - c_carry[3] * 2^C + c_carry[2] = c_expr[3] - let mut target = a_expr[0].clone() * b_expr[3].clone() - + a_expr[1].clone() * b_expr[2].clone() - + a_expr[2].clone() * b_expr[1].clone() - + a_expr[3].clone() * b_expr[0].clone() - + c_carries[2].expr(); - if let Some(overflow) = c_carries.get(3) { - target = target - overflow.expr() * Self::POW_OF_C.into(); - } - circuit_builder.require_equal(|| "c_expr[3]", target, c_expr[3].clone())?; + let carries = c.carries.as_ref().unwrap(); + + // compute the result + let mut result_c: Vec> = Vec::>::with_capacity(Self::NUM_CELLS); + a_expr.iter().enumerate().for_each(|(i, a)| { + b_expr.iter().enumerate().for_each(|(j, b)| { + let idx = i + j; + if idx < Self::NUM_CELLS { + if result_c.get(idx).is_none() { + result_c.push(a.clone() * b.clone()); + } else { + result_c[idx] = result_c[idx].clone() + a.clone() * b.clone(); + } + } + }); + + // take care carries + let carry = if i > 0 { carries.get(i - 1) } else { None }; + let next_carry = carries.get(i); + if carry.is_some() { + result_c[i] = result_c[i].clone() + carry.unwrap().expr(); + } + if next_carry.is_some() { + result_c[i] = + result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C.into(); + } + }); + + // result check + c_expr + .iter() + .zip(result_c) + .enumerate() + .for_each(|(i, (target, result))| { + circuit_builder + .require_equal(|| format!("c_expr{i}"), target.clone(), result) + .unwrap(); + }); + Ok(c) } From 1ded491a9c032ea4b725572096cf304b068ad6bc Mon Sep 17 00:00:00 2001 From: KimiWu Date: Mon, 9 Sep 2024 15:25:18 +0800 Subject: [PATCH 2/7] test: support abritrary len of limbs --- ceno_zkvm/src/uint/arithmetic.rs | 122 ++++++++++++++++--------------- 1 file changed, 63 insertions(+), 59 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index a6bfd91a8..1362ce74d 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -455,15 +455,9 @@ mod tests { eval_by_expr(&witness_values, &challenges, &c.expr()[3]), E::ZERO ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ZERO - ); + + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(c.carries.unwrap().len(), 3) } #[test] @@ -506,15 +500,8 @@ mod tests { eval_by_expr(&witness_values, &challenges, &c.expr()[3]), E::ZERO ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ZERO - ); + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(c.carries.unwrap().len(), 3) } #[test] @@ -556,15 +543,8 @@ mod tests { eval_by_expr(&witness_values, &challenges, &c.expr()[3]), E::ZERO ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ZERO - ); + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(c.carries.unwrap().len(), 3) } #[test] @@ -705,6 +685,7 @@ mod tests { scheme::utils::eval_by_expr, uint::UInt, }; + use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; @@ -719,9 +700,9 @@ mod tests { let wit_a = vec![1, 1, 0, 0]; let wit_b = vec![2, 1, 0, 0]; let wit_c = vec![2, 3, 1, 0]; - let wit_carries = vec![0, 0, 0, 0]; + let wit_carries = vec![0, 0, 0]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); - verify::(witness_values, false); + verify::<64, 16, E>(witness_values, false); } #[test] @@ -732,9 +713,9 @@ mod tests { let wit_a = vec![256, 1, 0, 0]; let wit_b = vec![257, 1, 0, 0]; let wit_c = vec![256, 514, 1, 0]; - let wit_carries = vec![1, 0, 0, 0]; + let wit_carries = vec![1, 0, 0]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); - verify::(witness_values, false); + verify::<64, 16, E>(witness_values, false); } #[test] @@ -748,51 +729,74 @@ mod tests { // ==> [256 + 1 * (2^16), 256 + 2 * (2^16), 0 + 1 * (2^16), 0] // so we get wit_c = [256, 256, 0, 0] and carries = [1, 2, 1, 0] let wit_c = vec![256, 257, 2, 1]; - let wit_carries = vec![1, 2, 1, 0]; + let wit_carries = vec![1, 2, 1]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); - verify::(witness_values, true); + verify::<64, 16, E>(witness_values, false); } - fn verify(witness_values: Vec, overflow: bool) { + fn verify( + witness_values: Vec, + overflow: bool, + ) { let mut cs = ConstraintSystem::new(|| "test"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let mut uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let mut uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); + let mut uint_a = UInt::::new(|| "uint_a", &mut circuit_builder).unwrap(); + let mut uint_b = UInt::::new(|| "uint_b", &mut circuit_builder).unwrap(); let uint_c = uint_a .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) .unwrap(); - let a = &witness_values[0..4]; - let b = &witness_values[4..8]; - let c_carries = &witness_values[12..16]; + let single_wit_size = UInt::::NUM_CELLS; + let wit_end_idx = if overflow { + 4 * single_wit_size + } else { + 4 * single_wit_size - 1 + }; + let a = &witness_values[0..single_wit_size]; + let b = &witness_values[single_wit_size..2 * single_wit_size]; + let c_carries = &witness_values[3 * single_wit_size..wit_end_idx]; // limbs cal. - let t0 = a[0] * b[0] - c_carries[0] * POW_OF_C; - let t1 = a[0] * b[1] + a[1] * b[0] - c_carries[1] * POW_OF_C + c_carries[0]; - let t2 = - a[0] * b[2] + a[1] * b[1] + a[2] * b[0] - c_carries[2] * POW_OF_C + c_carries[1]; - let t3 = a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0] - - c_carries[3] * POW_OF_C - + c_carries[2]; + let mut result = vec![0u64; single_wit_size]; + a.iter().enumerate().for_each(|(i, a_limb)| { + b.iter().enumerate().for_each(|(j, b_limb)| { + let idx = i + j; + if idx < single_wit_size { + result[idx] += a_limb * b_limb; + } + }); + + // take care carries + if !overflow && c_carries.get(i).is_some() { + result[i] -= c_carries[i] * POW_OF_C; + } + if i != 0 { + result[i] += c_carries[i - 1]; + } + }); // verify let c_expr = uint_c.expr(); - let w: Vec = witness_values.iter().map(|&a| a.into()).collect_vec(); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[0]), E::from(t0)); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[1]), E::from(t1)); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[2]), E::from(t2)); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[3]), E::from(t3)); + let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); + c_expr.iter().zip(result).for_each(|(c, ret)| { + assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); + }); // overflow - assert_eq!( - eval_by_expr( - &w, - &challenges, - &uint_c.carries.unwrap().last().unwrap().expr() - ), - if overflow { E::ONE } else { E::ZERO } - ); + if overflow { + assert_eq!( + eval_by_expr( + &wit, + &challenges, + &uint_c.carries.unwrap().last().unwrap().expr() + ), + E::ONE + ); + } else { + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) + } } } From 8c9f9508e3361f1a031613a656f41863785d587d Mon Sep 17 00:00:00 2001 From: KimiWu Date: Mon, 9 Sep 2024 17:56:16 +0800 Subject: [PATCH 3/7] test: add M/C!=4 cases for mul --- ceno_zkvm/src/uint/arithmetic.rs | 69 ++++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 1362ce74d..8a0e58f37 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -691,9 +691,20 @@ mod tests { use itertools::Itertools; type E = GoldilocksExt2; // 18446744069414584321 - const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; + // const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; + + // fn test_data() -> Vec> { + // let d = vec![ + // [256, 256, 0, 0].into(), + // [257, 256, 0, 0].into(), + // [256, 257, 2, 1].into(), + // [1, 2, 1].into(), + // ]; + // d + // } + #[test] - fn test_mul_no_carries() { + fn test_mul64_16_no_carries() { // a = 1 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 2 + 3 * 2^16 + 1 * 2^32 = 4,295,163,906 @@ -706,7 +717,7 @@ mod tests { } #[test] - fn test_mul_w_carry() { + fn test_mul64_16_w_carry() { // a = 256 + 1 * 2^16 // b = 257 + 1 * 2^16 // c = 256 + 514 * 2^16 + 1 * 2^32 = 4,328,653,056 @@ -719,7 +730,7 @@ mod tests { } #[test] - fn test_mul_w_carries() { + fn test_mul64_16_w_carries() { // a = 256 + 256 * 2^16 = 16,777,472 // b = 257 + 256 * 2^16 = 16,777,473 // c = 256 + 257 * 2^16 + 2 * 2^32 + 1 * 2^48 = 281,483,583,488,256 @@ -734,11 +745,50 @@ mod tests { verify::<64, 16, E>(witness_values, false); } + #[test] + fn test_mul64_8_w_carries() { + // a = 256 + // b = 257 + // c = 254 + 1 * 2^16 = 510 + let wit_a = vec![255, 0, 0, 0, 0, 0, 0, 0]; + let wit_b = vec![2, 0, 0, 0, 0, 0, 0, 0]; + let wit_c = vec![254, 1, 0, 0, 0, 0, 0, 0]; + let wit_carries = vec![1, 0, 0, 0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<64, 8, E>(witness_values, false); + } + + #[test] + fn test_mul32_16_w_carries() { + // a = 256 + // b = 257 + // c = 256 + 1 * 2^16 = 65,792 + let wit_a = vec![256, 0]; + let wit_b = vec![257, 0]; + let wit_c = vec![256, 1]; + let wit_carries = vec![1, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<32, 16, E>(witness_values, false); + } + + #[test] + fn test_mul32_5_w_carries() { + // a = 31 + // b = 2 + // c = 30 + 1 * 2^8 = 62 + let wit_a = vec![31, 0, 0, 0, 0, 0, 0]; + let wit_b = vec![2, 0, 0, 0, 0, 0, 0]; + let wit_c = vec![30, 1, 0, 0, 0, 0, 0]; + let wit_carries = vec![1, 0, 0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<32, 5, E>(witness_values, false); + } + fn verify( witness_values: Vec, overflow: bool, ) { - let mut cs = ConstraintSystem::new(|| "test"); + let mut cs = ConstraintSystem::new(|| "test_mul"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); @@ -748,6 +798,7 @@ mod tests { .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) .unwrap(); + let POW_OF_C: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; let single_wit_size = UInt::::NUM_CELLS; let wit_end_idx = if overflow { 4 * single_wit_size @@ -767,13 +818,15 @@ mod tests { result[idx] += a_limb * b_limb; } }); + }); - // take care carries + // take care carries + result.iter_mut().enumerate().for_each(|(i, ret)| { if !overflow && c_carries.get(i).is_some() { - result[i] -= c_carries[i] * POW_OF_C; + *ret -= c_carries[i] * POW_OF_C; } if i != 0 { - result[i] += c_carries[i - 1]; + *ret += c_carries[i - 1]; } }); From 4a0751383d0a284d0a828693dc41aefe9d4dc364 Mon Sep 17 00:00:00 2001 From: KimiWu Date: Tue, 10 Sep 2024 13:40:08 +0800 Subject: [PATCH 4/7] test: refactor 'add' verify func --- ceno_zkvm/src/uint/arithmetic.rs | 307 +++++++++++-------------------- 1 file changed, 103 insertions(+), 204 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 8a0e58f37..b3e313687 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -412,269 +412,168 @@ mod tests { uint::UInt, }; use ff::Field; + use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; type E = GoldilocksExt2; #[test] fn test_add_no_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - // a = 1 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 3 + 2 * 2^16 with 0 carries let a = vec![1, 1, 0, 0]; let b = vec![2, 1, 0, 0]; let carries = vec![0; 3]; // no overflow - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::from(2) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); - - // non-overflow case, the len of carries should be (NUM_CELLS - 1) - assert_eq!(c.carries.unwrap().len(), 3) + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] fn test_add_w_carry() { - type E = GoldilocksExt2; - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - // a = 65535 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 3 * 2^16 with carries [1, 0, 0, 0] let a = vec![0xFFFF, 1, 0, 0]; let b = vec![2, 1, 0, 0]; let carries = vec![1, 0, 0]; // no overflow - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); - // non-overflow case, the len of carries should be (NUM_CELLS - 1) - assert_eq!(c.carries.unwrap().len(), 3) + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] fn test_add_w_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - // a = 65535 + 65534 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] let a = vec![0xFFFF, 0xFFFE, 0, 0]; let b = vec![2, 1, 0, 0]; let carries = vec![1, 1, 0]; // no overflow - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); - // non-overflow case, the len of carries should be (NUM_CELLS - 1) - assert_eq!(c.carries.unwrap().len(), 3) + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] fn test_add_w_overflow() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - // a = 1 + 1 * 2^16 + 0 + 65535 * 2^48 // b = 2 + 1 * 2^16 + 0 + 2 * 2^48 // c = 3 + 2 * 2^16 + 0 + 1 * 2^48 with carries [0, 0, 0, 1] let a = vec![1, 1, 0, 0xFFFF]; let b = vec![2, 1, 0, 2]; let carries = vec![0, 0, 0, 1]; - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, true).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::from(2) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ONE - ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ONE - ); + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] fn test_add_const_no_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - // a = 1 + 1 * 2^16 // const b = 2 // c = 3 + 1 * 2^16 with 0 carries let a = vec![1, 1, 0, 0]; let carries = vec![0; 3]; // no overflow - let witness_values = [a, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = Expression::Constant(2.into()); - let c = a.add_const(|| "c", &mut circuit_builder, b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); + let witness_values = [a, carries].concat(); + verify::<64, 16, E>(witness_values, Some(2), false); } #[test] fn test_add_const_w_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - // a = 65535 + 65534 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] let a = vec![0xFFFF, 0xFFFE, 0, 0]; let carries = vec![1, 1, 0]; // no overflow - let witness_values = [a, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); + let witness_values = [a, carries].concat(); + verify::<64, 16, E>(witness_values, Some(65538), false); + } + + fn verify( + witness_values: Vec, + const_b: Option, + overflow: bool, + ) { + let mut cs = ConstraintSystem::new(|| "test_add"); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = Expression::Constant(65538.into()); - let c = a.add_const(|| "c", &mut circuit_builder, b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); + let mut uint_a = UInt::::new(|| "uint_a", &mut circuit_builder).unwrap(); + let uint_c = if const_b.is_none() { + let mut uint_b = UInt::::new(|| "uint_b", &mut circuit_builder).unwrap(); + uint_a + .add(|| "uint_c", &mut circuit_builder, &uint_b, overflow) + .unwrap() + } else { + uint_a + .add_const( + || "uint_c", + &mut circuit_builder, + Expression::Constant(const_b.unwrap().into()), + overflow, + ) + .unwrap() + }; + + let POW_OF_C: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; + let single_wit_size = UInt::::NUM_CELLS; + + let a = &witness_values[0..single_wit_size]; + let mut const_b_pre_allocated = vec![0u64; single_wit_size]; + let b = if const_b.is_none() { + &witness_values[single_wit_size..2 * single_wit_size] + } else { + let b = const_b.unwrap(); + let LIMB_BIT_MASK: u64 = (1 << C) - 1; + const_b_pre_allocated + .iter_mut() + .enumerate() + .for_each(|(i, limb)| *limb = (b >> (C * i)) & LIMB_BIT_MASK); + &const_b_pre_allocated + }; + + let num_witness = if const_b.is_none() { 3 } else { 2 }; + let wit_end_idx = if overflow { + num_witness * single_wit_size + } else { + num_witness * single_wit_size - 1 + }; + let carries = &witness_values[(num_witness - 1) * single_wit_size..wit_end_idx]; + + // limbs cal. + let mut result = vec![0u64; single_wit_size]; + a.iter() + .zip(b) + .enumerate() + .for_each(|(i, (&limb_a, &limb_b))| { + let carry = carries.get(i); + result[i] = limb_a + limb_b; + if i != 0 { + result[i] += carries[i - 1]; + } + if !overflow && carry.is_some() { + result[i] -= carry.unwrap() * POW_OF_C; + } + }); + + let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); + let c_expr = uint_c.expr(); + c_expr.iter().zip(result).for_each(|(c, ret)| { + assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); + }); + + // overflow + if overflow { + assert_eq!( + eval_by_expr( + &wit, + &challenges, + &uint_c.carries.unwrap().last().unwrap().expr() + ), + E::ONE + ); + } else { + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) + } } } @@ -822,12 +721,12 @@ mod tests { // take care carries result.iter_mut().enumerate().for_each(|(i, ret)| { - if !overflow && c_carries.get(i).is_some() { - *ret -= c_carries[i] * POW_OF_C; - } if i != 0 { *ret += c_carries[i - 1]; } + if !overflow && c_carries.get(i).is_some() { + *ret -= c_carries[i] * POW_OF_C; + } }); // verify From 170b214a5df48f35f8502339221a5d9f151f6ce6 Mon Sep 17 00:00:00 2001 From: KimiWu Date: Tue, 10 Sep 2024 14:11:08 +0800 Subject: [PATCH 5/7] test: add M/C!=4 cases for add/add_const --- ceno_zkvm/src/uint/arithmetic.rs | 202 ++++++++++++++++--------------- 1 file changed, 105 insertions(+), 97 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index b3e313687..45cd37b18 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -411,14 +411,13 @@ mod tests { scheme::utils::eval_by_expr, uint::UInt, }; - use ff::Field; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; type E = GoldilocksExt2; #[test] - fn test_add_no_carries() { + fn test_add64_16_no_carries() { // a = 1 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 3 + 2 * 2^16 with 0 carries @@ -430,7 +429,7 @@ mod tests { } #[test] - fn test_add_w_carry() { + fn test_add64_16_w_carry() { // a = 65535 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 3 * 2^16 with carries [1, 0, 0, 0] @@ -442,7 +441,7 @@ mod tests { } #[test] - fn test_add_w_carries() { + fn test_add64_16_w_carries() { // a = 65535 + 65534 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] @@ -454,7 +453,7 @@ mod tests { } #[test] - fn test_add_w_overflow() { + fn test_add64_16_w_overflow() { // a = 1 + 1 * 2^16 + 0 + 65535 * 2^48 // b = 2 + 1 * 2^16 + 0 + 2 * 2^48 // c = 3 + 2 * 2^16 + 0 + 1 * 2^48 with carries [0, 0, 0, 1] @@ -466,7 +465,31 @@ mod tests { } #[test] - fn test_add_const_no_carries() { + fn test_add32_16_w_carry() { + // a = 65535 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // c = 1 + 3 * 2^16 with carries [1] + let a = vec![0xFFFF, 1]; + let b = vec![2, 1]; + let carries = vec![1]; // no overflow + let witness_values = [a, b, carries].concat(); + verify::<32, 16, E>(witness_values, None, false); + } + + #[test] + fn test_add32_5_w_carry() { + // a = 31 + // b = 2 + 1 * 2^5 + // c = 1 + 1 * 2^5 with carries [1, 0, 0, 0] + let a = vec![31, 1, 0, 0, 0, 0, 0]; + let b = vec![2, 1, 0, 0, 0, 0, 0]; + let carries = vec![1, 0, 0, 0, 0, 0]; // no overflow + let witness_values = [a, b, carries].concat(); + verify::<32, 5, E>(witness_values, None, false); + } + + #[test] + fn test_add_const64_16_no_carries() { // a = 1 + 1 * 2^16 // const b = 2 // c = 3 + 1 * 2^16 with 0 carries @@ -477,9 +500,9 @@ mod tests { } #[test] - fn test_add_const_w_carries() { + fn test_add_const64_16_w_carries() { // a = 65535 + 65534 * 2^16 - // b = 2 + 1 * 2^16 + // const b = 2 + 1 * 2^16 = 65,538 // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] let a = vec![0xFFFF, 0xFFFE, 0, 0]; let carries = vec![1, 1, 0]; // no overflow @@ -487,33 +510,49 @@ mod tests { verify::<64, 16, E>(witness_values, Some(65538), false); } + #[test] + fn test_add_const32_16_w_carry() { + // a = 65535 + 1 * 2^16 + // const b = 2 + 1 * 2^16 = 65,538 + // c = 1 + 3 * 2^16 with carries [1] + let a = vec![0xFFFF, 1]; + let carries = vec![1]; // no overflow + let witness_values = [a, carries].concat(); + verify::<32, 16, E>(witness_values, Some(65538), false); + } + + #[test] + fn test_add_const32_5_w_carry() { + // a = 31 + // const b = 2 + 1 * 2^5 = 34 + // c = 1 + 1 * 2^5 with carries [1, 0, 0, 0] + let a = vec![31, 1, 0, 0, 0, 0, 0]; + let carries = vec![1, 0, 0, 0, 0, 0]; // no overflow + let witness_values = [a, carries].concat(); + verify::<32, 5, E>(witness_values, Some(34), false); + } + fn verify( witness_values: Vec, const_b: Option, overflow: bool, ) { let mut cs = ConstraintSystem::new(|| "test_add"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let mut uint_a = UInt::::new(|| "uint_a", &mut circuit_builder).unwrap(); + let uint_a = UInt::::new(|| "uint_a", &mut cb).unwrap(); let uint_c = if const_b.is_none() { - let mut uint_b = UInt::::new(|| "uint_b", &mut circuit_builder).unwrap(); - uint_a - .add(|| "uint_c", &mut circuit_builder, &uint_b, overflow) - .unwrap() + let uint_b = UInt::::new(|| "uint_b", &mut cb).unwrap(); + uint_a.add(|| "uint_c", &mut cb, &uint_b, overflow).unwrap() } else { + let const_b = Expression::Constant(const_b.unwrap().into()); uint_a - .add_const( - || "uint_c", - &mut circuit_builder, - Expression::Constant(const_b.unwrap().into()), - overflow, - ) + .add_const(|| "uint_c", &mut cb, const_b, overflow) .unwrap() }; - let POW_OF_C: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; + let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; let single_wit_size = UInt::::NUM_CELLS; let a = &witness_values[0..single_wit_size]; @@ -522,14 +561,16 @@ mod tests { &witness_values[single_wit_size..2 * single_wit_size] } else { let b = const_b.unwrap(); - let LIMB_BIT_MASK: u64 = (1 << C) - 1; + let limb_bit_mask: u64 = (1 << C) - 1; const_b_pre_allocated .iter_mut() .enumerate() - .for_each(|(i, limb)| *limb = (b >> (C * i)) & LIMB_BIT_MASK); + .for_each(|(i, limb)| *limb = (b >> (C * i)) & limb_bit_mask); &const_b_pre_allocated }; + // the num of witness is 3, a, b and c_carries if it's a `add` + // only the num is 2 if it's a `add_const` bcs there is no `b` let num_witness = if const_b.is_none() { 3 } else { 2 }; let wit_end_idx = if overflow { num_witness * single_wit_size @@ -550,26 +591,20 @@ mod tests { result[i] += carries[i - 1]; } if !overflow && carry.is_some() { - result[i] -= carry.unwrap() * POW_OF_C; + result[i] -= carry.unwrap() * pow_of_c; } }); + // verify let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); - let c_expr = uint_c.expr(); - c_expr.iter().zip(result).for_each(|(c, ret)| { + uint_c.expr().iter().zip(result).for_each(|(c, ret)| { assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); }); // overflow if overflow { - assert_eq!( - eval_by_expr( - &wit, - &challenges, - &uint_c.carries.unwrap().last().unwrap().expr() - ), - E::ONE - ); + let carries = uint_c.carries.unwrap().last().unwrap().expr(); + assert_eq!(eval_by_expr(&wit, &challenges, &carries), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) @@ -584,24 +619,11 @@ mod tests { scheme::utils::eval_by_expr, uint::UInt, }; - use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; type E = GoldilocksExt2; // 18446744069414584321 - // const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; - - // fn test_data() -> Vec> { - // let d = vec![ - // [256, 256, 0, 0].into(), - // [257, 256, 0, 0].into(), - // [256, 257, 2, 1].into(), - // [1, 2, 1].into(), - // ]; - // d - // } - #[test] fn test_mul64_16_no_carries() { // a = 1 + 1 * 2^16 @@ -688,16 +710,16 @@ mod tests { overflow: bool, ) { let mut cs = ConstraintSystem::new(|| "test_mul"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let mut uint_a = UInt::::new(|| "uint_a", &mut circuit_builder).unwrap(); - let mut uint_b = UInt::::new(|| "uint_b", &mut circuit_builder).unwrap(); + let mut uint_a = UInt::::new(|| "uint_a", &mut cb).unwrap(); + let mut uint_b = UInt::::new(|| "uint_b", &mut cb).unwrap(); let uint_c = uint_a - .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) + .mul(|| "uint_c", &mut cb, &mut uint_b, false) .unwrap(); - let POW_OF_C: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; + let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; let single_wit_size = UInt::::NUM_CELLS; let wit_end_idx = if overflow { 4 * single_wit_size @@ -706,7 +728,7 @@ mod tests { }; let a = &witness_values[0..single_wit_size]; let b = &witness_values[single_wit_size..2 * single_wit_size]; - let c_carries = &witness_values[3 * single_wit_size..wit_end_idx]; + let carries = &witness_values[3 * single_wit_size..wit_end_idx]; // limbs cal. let mut result = vec![0u64; single_wit_size]; @@ -722,29 +744,23 @@ mod tests { // take care carries result.iter_mut().enumerate().for_each(|(i, ret)| { if i != 0 { - *ret += c_carries[i - 1]; + *ret += carries[i - 1]; } - if !overflow && c_carries.get(i).is_some() { - *ret -= c_carries[i] * POW_OF_C; + if !overflow && carries.get(i).is_some() { + *ret -= carries[i] * pow_of_c; } }); // verify - let c_expr = uint_c.expr(); let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); - c_expr.iter().zip(result).for_each(|(c, ret)| { + uint_c.expr().iter().zip(result).for_each(|(c, ret)| { assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); }); + // overflow if overflow { - assert_eq!( - eval_by_expr( - &wit, - &challenges, - &uint_c.carries.unwrap().last().unwrap().expr() - ), - E::ONE - ); + let carries = uint_c.carries.unwrap().last().unwrap().expr(); + assert_eq!(eval_by_expr(&wit, &challenges, &carries), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) @@ -799,18 +815,16 @@ mod tests { .map(|&a| a.into()) .collect_vec(); - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cs = ConstraintSystem::new(|| "test_add_mul"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); - let mut uint_c = uint_a - .add(|| "uint_c", &mut circuit_builder, &uint_b, false) - .unwrap(); - let mut uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut circuit_builder).unwrap(); + let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut cb).unwrap(); + let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); + let mut uint_c = uint_a.add(|| "uint_c", &mut cb, &uint_b, false).unwrap(); + let mut uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); let uint_e = uint_c - .mul(|| "uint_e", &mut circuit_builder, &mut uint_d, false) + .mul(|| "uint_e", &mut cb, &mut uint_d, false) .unwrap(); uint_e.expr().iter().enumerate().for_each(|(i, ret)| { @@ -866,22 +880,18 @@ mod tests { .map(|&a| a.into()) .collect_vec(); - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cs = ConstraintSystem::new(|| "test_add_mul2"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); - let mut uint_c = uint_a - .add(|| "uint_c", &mut circuit_builder, &uint_b, false) - .unwrap(); - let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut circuit_builder).unwrap(); - let uint_e = UInt::<64, 16, E>::new(|| "uint_e", &mut circuit_builder).unwrap(); - let mut uint_f = uint_d - .add(|| "uint_f", &mut circuit_builder, &uint_e, false) - .unwrap(); + let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut cb).unwrap(); + let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); + let mut uint_c = uint_a.add(|| "uint_c", &mut cb, &uint_b, false).unwrap(); + let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); + let uint_e = UInt::<64, 16, E>::new(|| "uint_e", &mut cb).unwrap(); + let mut uint_f = uint_d.add(|| "uint_f", &mut cb, &uint_e, false).unwrap(); let uint_g = uint_c - .mul(|| "unit_g", &mut circuit_builder, &mut uint_f, false) + .mul(|| "unit_g", &mut cb, &mut uint_f, false) .unwrap(); uint_g.expr().iter().enumerate().for_each(|(i, ret)| { @@ -918,19 +928,17 @@ mod tests { .map(|&a| a.into()) .collect_vec(); - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cs = ConstraintSystem::new(|| "test_mul_add"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let mut uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let mut uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); + let mut uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut cb).unwrap(); + let mut uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); let uint_c = uint_a - .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) - .unwrap(); - let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut circuit_builder).unwrap(); - let uint_e = uint_c - .add(|| "uint_e", &mut circuit_builder, &uint_d, false) + .mul(|| "uint_c", &mut cb, &mut uint_b, false) .unwrap(); + let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); + let uint_e = uint_c.add(|| "uint_e", &mut cb, &uint_d, false).unwrap(); uint_e.expr().iter().enumerate().for_each(|(i, ret)| { // limbs check From b167db66df2723cba6796aca60a3431929b82326 Mon Sep 17 00:00:00 2001 From: KimiWu Date: Tue, 10 Sep 2024 14:25:24 +0800 Subject: [PATCH 6/7] refactor: carry handling in add --- ceno_zkvm/src/uint/arithmetic.rs | 35 +++++++++++++------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 45cd37b18..1a9032078 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -25,7 +25,7 @@ impl UInt { let mut c = UInt::::new_as_empty(); // allocate witness cells and do range checks for carries - c.create_carry_witin(|| "carry", circuit_builder, with_overflow)?; + c.create_carry_witin(|| "add_carry", circuit_builder, with_overflow)?; // perform add operation // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C @@ -36,25 +36,18 @@ impl UInt { .enumerate() .map(|(i, (a, b))| { let carries = c.carries.as_ref().unwrap(); - let limb_expr = match ( - if i > 0 { carries.get(i - 1) } else { None }, - carries.get(i), - ) { - // first limb - (None, Some(next_carry)) => { - a.clone() + b.clone() - next_carry.expr() * Self::POW_OF_C.into() - } - // assert no overflow - (Some(carry), None) => { - debug_assert!(!with_overflow); - a.clone() + b.clone() + carry.expr() - } - (Some(carry), Some(next_carry)) => { - a.clone() + b.clone() + carry.expr() - - next_carry.expr() * Self::POW_OF_C.into() - } - (None, None) => unreachable!(), - }; + let carry = if i > 0 { carries.get(i - 1) } else { None }; + let next_carry = carries.get(i); + + let mut limb_expr = a.clone() + b.clone(); + if carry.is_some() { + limb_expr = limb_expr.clone() + carry.unwrap().expr(); + } + if next_carry.is_some() { + limb_expr = + limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C.into(); + } + circuit_builder .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb_expr.clone())?; Ok(limb_expr) @@ -110,7 +103,7 @@ impl UInt { ) -> Result, ZKVMError> { let mut c = UInt::::new(|| "c", circuit_builder)?; // allocate witness cells and do range checks for carries - c.create_carry_witin(|| "carry", circuit_builder, with_overflow)?; + c.create_carry_witin(|| "mul_carry", circuit_builder, with_overflow)?; // We only allow expressions are in monomial form // if any of a or b is in Expression term, it would cause error. From d0c127465c413d88043075e5a45fe4e022b83c11 Mon Sep 17 00:00:00 2001 From: KimiWu Date: Wed, 11 Sep 2024 10:50:36 +0800 Subject: [PATCH 7/7] test: add a overflow case in mul --- ceno_zkvm/src/uint/arithmetic.rs | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 1a9032078..f362ae12d 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -659,6 +659,20 @@ mod tests { verify::<64, 16, E>(witness_values, false); } + #[test] + fn test_mul64_16_w_overflow() { + // 18,446,744,073,709,551,616 + // a = 1 * 2^16 + 1 * 2^32 = 4,295,032,832 + // b = 1 * 2^32 = 4,294,967,296 + // c = 1 * 2^48 + 1 * 2^64 = 18,447,025,548,686,262,272 + let wit_a = vec![0, 1, 1, 0]; + let wit_b = vec![0, 0, 1, 0]; + let wit_c = vec![0, 0, 0, 1]; + let wit_carries = vec![0, 0, 0, 1]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<64, 16, E>(witness_values, true); + } + #[test] fn test_mul64_8_w_carries() { // a = 256 @@ -709,7 +723,7 @@ mod tests { let mut uint_a = UInt::::new(|| "uint_a", &mut cb).unwrap(); let mut uint_b = UInt::::new(|| "uint_b", &mut cb).unwrap(); let uint_c = uint_a - .mul(|| "uint_c", &mut cb, &mut uint_b, false) + .mul(|| "uint_c", &mut cb, &mut uint_b, overflow) .unwrap(); let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; @@ -752,8 +766,8 @@ mod tests { // overflow if overflow { - let carries = uint_c.carries.unwrap().last().unwrap().expr(); - assert_eq!(eval_by_expr(&wit, &challenges, &carries), E::ONE); + let overflow = uint_c.carries.unwrap().last().unwrap().expr(); + assert_eq!(eval_by_expr(&wit, &challenges, &overflow), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1)