diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index caeb59273..f362ae12d 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -25,7 +25,7 @@ impl UInt { let mut c = UInt::::new_as_empty(); // allocate witness cells and do range checks for carries - c.create_carry_witin(|| "carry", circuit_builder, with_overflow)?; + c.create_carry_witin(|| "add_carry", circuit_builder, with_overflow)?; // perform add operation // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C @@ -36,25 +36,18 @@ impl UInt { .enumerate() .map(|(i, (a, b))| { let carries = c.carries.as_ref().unwrap(); - let limb_expr = match ( - if i > 0 { carries.get(i - 1) } else { None }, - carries.get(i), - ) { - // first limb - (None, Some(next_carry)) => { - a.clone() + b.clone() - next_carry.expr() * Self::POW_OF_C.into() - } - // assert no overflow - (Some(carry), None) => { - debug_assert!(!with_overflow); - a.clone() + b.clone() + carry.expr() - } - (Some(carry), Some(next_carry)) => { - a.clone() + b.clone() + carry.expr() - - next_carry.expr() * Self::POW_OF_C.into() - } - (None, None) => unreachable!(), - }; + let carry = if i > 0 { carries.get(i - 1) } else { None }; + let next_carry = carries.get(i); + + let mut limb_expr = a.clone() + b.clone(); + if carry.is_some() { + limb_expr = limb_expr.clone() + carry.unwrap().expr(); + } + if next_carry.is_some() { + limb_expr = + limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C.into(); + } + circuit_builder .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb_expr.clone())?; Ok(limb_expr) @@ -110,7 +103,7 @@ impl UInt { ) -> Result, ZKVMError> { let mut c = UInt::::new(|| "c", circuit_builder)?; // allocate witness cells and do range checks for carries - c.create_carry_witin(|| "carry", circuit_builder, with_overflow)?; + c.create_carry_witin(|| "mul_carry", circuit_builder, with_overflow)?; // We only allow expressions are in monomial form // if any of a or b is in Expression term, it would cause error. @@ -133,44 +126,45 @@ impl UInt { // result check let c_expr = c.expr(); - let c_carries = c.carries.as_ref().unwrap(); - - // TODO #174 - // a_expr[0] * b_expr[0] - c_carry[0] * 2^C = c_expr[0] - circuit_builder.require_equal( - || "c_expr[0]", - a_expr[0].clone() * b_expr[0].clone() - c_carries[0].expr() * Self::POW_OF_C.into(), - c_expr[0].clone(), - )?; - // a_expr[0] * b_expr[1] + a_expr[1] * b_expr[0] - c_carry[1] * 2^C + c_carry[0] = c_expr[1] - circuit_builder.require_equal( - || "c_expr[1]", - a_expr[0].clone() * b_expr[0].clone() - c_carries[1].expr() * Self::POW_OF_C.into() - + c_carries[0].expr(), - c_expr[1].clone(), - )?; - // a_expr[0] * b_expr[2] + a_expr[1] * b_expr[1] + a_expr[2] * b_expr[0] - - // c_carry[2] * 2^C + c_carry[1] = c_expr[2] - circuit_builder.require_equal( - || "c_expr[2]", - a_expr[0].clone() * b_expr[2].clone() - + a_expr[1].clone() * b_expr[1].clone() - + a_expr[2].clone() * b_expr[0].clone() - - c_carries[2].expr() * Self::POW_OF_C.into() - + c_carries[1].expr(), - c_expr[2].clone(), - )?; - // a_expr[0] * b_expr[3] + a_expr[1] * b_expr[2] + a_expr[2] * b_expr[1] + - // a_expr[3] * b_expr[0] - c_carry[3] * 2^C + c_carry[2] = c_expr[3] - let mut target = a_expr[0].clone() * b_expr[3].clone() - + a_expr[1].clone() * b_expr[2].clone() - + a_expr[2].clone() * b_expr[1].clone() - + a_expr[3].clone() * b_expr[0].clone() - + c_carries[2].expr(); - if let Some(overflow) = c_carries.get(3) { - target = target - overflow.expr() * Self::POW_OF_C.into(); - } - circuit_builder.require_equal(|| "c_expr[3]", target, c_expr[3].clone())?; + let carries = c.carries.as_ref().unwrap(); + + // compute the result + let mut result_c: Vec> = Vec::>::with_capacity(Self::NUM_CELLS); + a_expr.iter().enumerate().for_each(|(i, a)| { + b_expr.iter().enumerate().for_each(|(j, b)| { + let idx = i + j; + if idx < Self::NUM_CELLS { + if result_c.get(idx).is_none() { + result_c.push(a.clone() * b.clone()); + } else { + result_c[idx] = result_c[idx].clone() + a.clone() * b.clone(); + } + } + }); + + // take care carries + let carry = if i > 0 { carries.get(i - 1) } else { None }; + let next_carry = carries.get(i); + if carry.is_some() { + result_c[i] = result_c[i].clone() + carry.unwrap().expr(); + } + if next_carry.is_some() { + result_c[i] = + result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C.into(); + } + }); + + // result check + c_expr + .iter() + .zip(result_c) + .enumerate() + .for_each(|(i, (target, result))| { + circuit_builder + .require_equal(|| format!("c_expr{i}"), target.clone(), result) + .unwrap(); + }); + Ok(c) } @@ -410,290 +404,204 @@ mod tests { scheme::utils::eval_by_expr, uint::UInt, }; - use ff::Field; + use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; use itertools::Itertools; type E = GoldilocksExt2; #[test] - fn test_add_no_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - + fn test_add64_16_no_carries() { // a = 1 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 3 + 2 * 2^16 with 0 carries let a = vec![1, 1, 0, 0]; let b = vec![2, 1, 0, 0]; let carries = vec![0; 3]; // no overflow - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::from(2) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ZERO - ); + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] - fn test_add_w_carry() { - type E = GoldilocksExt2; - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - + fn test_add64_16_w_carry() { // a = 65535 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 3 * 2^16 with carries [1, 0, 0, 0] let a = vec![0xFFFF, 1, 0, 0]; let b = vec![2, 1, 0, 0]; let carries = vec![1, 0, 0]; // no overflow - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ZERO - ); + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] - fn test_add_w_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - + fn test_add64_16_w_carries() { // a = 65535 + 65534 * 2^16 // b = 2 + 1 * 2^16 // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] let a = vec![0xFFFF, 0xFFFE, 0, 0]; let b = vec![2, 1, 0, 0]; let carries = vec![1, 1, 0]; // no overflow - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ZERO - ); + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); } #[test] - fn test_add_w_overflow() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - + fn test_add64_16_w_overflow() { // a = 1 + 1 * 2^16 + 0 + 65535 * 2^48 // b = 2 + 1 * 2^16 + 0 + 2 * 2^48 // c = 3 + 2 * 2^16 + 0 + 1 * 2^48 with carries [0, 0, 0, 1] let a = vec![1, 1, 0, 0xFFFF]; let b = vec![2, 1, 0, 2]; let carries = vec![0, 0, 0, 1]; - let witness_values = [a, b, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + let witness_values = [a, b, carries].concat(); + verify::<64, 16, E>(witness_values, None, false); + } - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = UInt::<64, 16, E>::new(|| "b", &mut circuit_builder).unwrap(); - let c = a.add(|| "c", &mut circuit_builder, &b, true).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::from(2) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ONE - ); - // overflow - assert_eq!( - eval_by_expr( - &witness_values, - &challenges, - &c.carries.unwrap().last().unwrap().expr() - ), - E::ONE - ); + #[test] + fn test_add32_16_w_carry() { + // a = 65535 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // c = 1 + 3 * 2^16 with carries [1] + let a = vec![0xFFFF, 1]; + let b = vec![2, 1]; + let carries = vec![1]; // no overflow + let witness_values = [a, b, carries].concat(); + verify::<32, 16, E>(witness_values, None, false); } #[test] - fn test_add_const_no_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + fn test_add32_5_w_carry() { + // a = 31 + // b = 2 + 1 * 2^5 + // c = 1 + 1 * 2^5 with carries [1, 0, 0, 0] + let a = vec![31, 1, 0, 0, 0, 0, 0]; + let b = vec![2, 1, 0, 0, 0, 0, 0]; + let carries = vec![1, 0, 0, 0, 0, 0]; // no overflow + let witness_values = [a, b, carries].concat(); + verify::<32, 5, E>(witness_values, None, false); + } + #[test] + fn test_add_const64_16_no_carries() { // a = 1 + 1 * 2^16 // const b = 2 // c = 3 + 1 * 2^16 with 0 carries let a = vec![1, 1, 0, 0]; let carries = vec![0; 3]; // no overflow - let witness_values = [a, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); - let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = Expression::Constant(2.into()); - let c = a.add_const(|| "c", &mut circuit_builder, b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::from(3) - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); + let witness_values = [a, carries].concat(); + verify::<64, 16, E>(witness_values, Some(2), false); } #[test] - fn test_add_const_w_carries() { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); - + fn test_add_const64_16_w_carries() { // a = 65535 + 65534 * 2^16 - // b = 2 + 1 * 2^16 + // const b = 2 + 1 * 2^16 = 65,538 // c = 1 + 0 * 2^16 + 1 * 2^32 with carries [1, 1, 0, 0] let a = vec![0xFFFF, 0xFFFE, 0, 0]; let carries = vec![1, 1, 0]; // no overflow - let witness_values = [a, carries] - .concat() - .iter() - .map(|&a| a.into()) - .collect_vec(); + let witness_values = [a, carries].concat(); + verify::<64, 16, E>(witness_values, Some(65538), false); + } + + #[test] + fn test_add_const32_16_w_carry() { + // a = 65535 + 1 * 2^16 + // const b = 2 + 1 * 2^16 = 65,538 + // c = 1 + 3 * 2^16 with carries [1] + let a = vec![0xFFFF, 1]; + let carries = vec![1]; // no overflow + let witness_values = [a, carries].concat(); + verify::<32, 16, E>(witness_values, Some(65538), false); + } + + #[test] + fn test_add_const32_5_w_carry() { + // a = 31 + // const b = 2 + 1 * 2^5 = 34 + // c = 1 + 1 * 2^5 with carries [1, 0, 0, 0] + let a = vec![31, 1, 0, 0, 0, 0, 0]; + let carries = vec![1, 0, 0, 0, 0, 0]; // no overflow + let witness_values = [a, carries].concat(); + verify::<32, 5, E>(witness_values, Some(34), false); + } + + fn verify( + witness_values: Vec, + const_b: Option, + overflow: bool, + ) { + let mut cs = ConstraintSystem::new(|| "test_add"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let a = UInt::<64, 16, E>::new(|| "a", &mut circuit_builder).unwrap(); - let b = Expression::Constant(65538.into()); - let c = a.add_const(|| "c", &mut circuit_builder, b, false).unwrap(); - - // verify limb_c[] = limb_a[] + limb_b[] - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[0]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[1]), - E::ZERO - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[2]), - E::ONE - ); - assert_eq!( - eval_by_expr(&witness_values, &challenges, &c.expr()[3]), - E::ZERO - ); + let uint_a = UInt::::new(|| "uint_a", &mut cb).unwrap(); + let uint_c = if const_b.is_none() { + let uint_b = UInt::::new(|| "uint_b", &mut cb).unwrap(); + uint_a.add(|| "uint_c", &mut cb, &uint_b, overflow).unwrap() + } else { + let const_b = Expression::Constant(const_b.unwrap().into()); + uint_a + .add_const(|| "uint_c", &mut cb, const_b, overflow) + .unwrap() + }; + + let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; + let single_wit_size = UInt::::NUM_CELLS; + + let a = &witness_values[0..single_wit_size]; + let mut const_b_pre_allocated = vec![0u64; single_wit_size]; + let b = if const_b.is_none() { + &witness_values[single_wit_size..2 * single_wit_size] + } else { + let b = const_b.unwrap(); + let limb_bit_mask: u64 = (1 << C) - 1; + const_b_pre_allocated + .iter_mut() + .enumerate() + .for_each(|(i, limb)| *limb = (b >> (C * i)) & limb_bit_mask); + &const_b_pre_allocated + }; + + // the num of witness is 3, a, b and c_carries if it's a `add` + // only the num is 2 if it's a `add_const` bcs there is no `b` + let num_witness = if const_b.is_none() { 3 } else { 2 }; + let wit_end_idx = if overflow { + num_witness * single_wit_size + } else { + num_witness * single_wit_size - 1 + }; + let carries = &witness_values[(num_witness - 1) * single_wit_size..wit_end_idx]; + + // limbs cal. + let mut result = vec![0u64; single_wit_size]; + a.iter() + .zip(b) + .enumerate() + .for_each(|(i, (&limb_a, &limb_b))| { + let carry = carries.get(i); + result[i] = limb_a + limb_b; + if i != 0 { + result[i] += carries[i - 1]; + } + if !overflow && carry.is_some() { + result[i] -= carry.unwrap() * pow_of_c; + } + }); + + // verify + let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); + uint_c.expr().iter().zip(result).for_each(|(c, ret)| { + assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); + }); + + // overflow + if overflow { + let carries = uint_c.carries.unwrap().last().unwrap().expr(); + assert_eq!(eval_by_expr(&wit, &challenges, &carries), E::ONE); + } else { + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) + } } } @@ -709,35 +617,34 @@ mod tests { use itertools::Itertools; type E = GoldilocksExt2; // 18446744069414584321 - const POW_OF_C: u64 = 2_usize.pow(16u32) as u64; #[test] - fn test_mul_no_carries() { + fn test_mul64_16_no_carries() { // a = 1 + 1 * 2^16 // b = 2 + 1 * 2^16 // c = 2 + 3 * 2^16 + 1 * 2^32 = 4,295,163,906 let wit_a = vec![1, 1, 0, 0]; let wit_b = vec![2, 1, 0, 0]; let wit_c = vec![2, 3, 1, 0]; - let wit_carries = vec![0, 0, 0, 0]; + let wit_carries = vec![0, 0, 0]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); - verify::(witness_values, false); + verify::<64, 16, E>(witness_values, false); } #[test] - fn test_mul_w_carry() { + fn test_mul64_16_w_carry() { // a = 256 + 1 * 2^16 // b = 257 + 1 * 2^16 // c = 256 + 514 * 2^16 + 1 * 2^32 = 4,328,653,056 let wit_a = vec![256, 1, 0, 0]; let wit_b = vec![257, 1, 0, 0]; let wit_c = vec![256, 514, 1, 0]; - let wit_carries = vec![1, 0, 0, 0]; + let wit_carries = vec![1, 0, 0]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); - verify::(witness_values, false); + verify::<64, 16, E>(witness_values, false); } #[test] - fn test_mul_w_carries() { + fn test_mul64_16_w_carries() { // a = 256 + 256 * 2^16 = 16,777,472 // b = 257 + 256 * 2^16 = 16,777,473 // c = 256 + 257 * 2^16 + 2 * 2^32 + 1 * 2^48 = 281,483,583,488,256 @@ -747,51 +654,124 @@ mod tests { // ==> [256 + 1 * (2^16), 256 + 2 * (2^16), 0 + 1 * (2^16), 0] // so we get wit_c = [256, 256, 0, 0] and carries = [1, 2, 1, 0] let wit_c = vec![256, 257, 2, 1]; - let wit_carries = vec![1, 2, 1, 0]; + let wit_carries = vec![1, 2, 1]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); - verify::(witness_values, true); + verify::<64, 16, E>(witness_values, false); } - fn verify(witness_values: Vec, overflow: bool) { - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + #[test] + fn test_mul64_16_w_overflow() { + // 18,446,744,073,709,551,616 + // a = 1 * 2^16 + 1 * 2^32 = 4,295,032,832 + // b = 1 * 2^32 = 4,294,967,296 + // c = 1 * 2^48 + 1 * 2^64 = 18,447,025,548,686,262,272 + let wit_a = vec![0, 1, 1, 0]; + let wit_b = vec![0, 0, 1, 0]; + let wit_c = vec![0, 0, 0, 1]; + let wit_carries = vec![0, 0, 0, 1]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<64, 16, E>(witness_values, true); + } + + #[test] + fn test_mul64_8_w_carries() { + // a = 256 + // b = 257 + // c = 254 + 1 * 2^16 = 510 + let wit_a = vec![255, 0, 0, 0, 0, 0, 0, 0]; + let wit_b = vec![2, 0, 0, 0, 0, 0, 0, 0]; + let wit_c = vec![254, 1, 0, 0, 0, 0, 0, 0]; + let wit_carries = vec![1, 0, 0, 0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<64, 8, E>(witness_values, false); + } + + #[test] + fn test_mul32_16_w_carries() { + // a = 256 + // b = 257 + // c = 256 + 1 * 2^16 = 65,792 + let wit_a = vec![256, 0]; + let wit_b = vec![257, 0]; + let wit_c = vec![256, 1]; + let wit_carries = vec![1, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<32, 16, E>(witness_values, false); + } + + #[test] + fn test_mul32_5_w_carries() { + // a = 31 + // b = 2 + // c = 30 + 1 * 2^8 = 62 + let wit_a = vec![31, 0, 0, 0, 0, 0, 0]; + let wit_b = vec![2, 0, 0, 0, 0, 0, 0]; + let wit_c = vec![30, 1, 0, 0, 0, 0, 0]; + let wit_carries = vec![1, 0, 0, 0, 0, 0]; + let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); + verify::<32, 5, E>(witness_values, false); + } + + fn verify( + witness_values: Vec, + overflow: bool, + ) { + let mut cs = ConstraintSystem::new(|| "test_mul"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let mut uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let mut uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); + let mut uint_a = UInt::::new(|| "uint_a", &mut cb).unwrap(); + let mut uint_b = UInt::::new(|| "uint_b", &mut cb).unwrap(); let uint_c = uint_a - .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) + .mul(|| "uint_c", &mut cb, &mut uint_b, overflow) .unwrap(); - let a = &witness_values[0..4]; - let b = &witness_values[4..8]; - let c_carries = &witness_values[12..16]; + let pow_of_c: u64 = 2_usize.pow(UInt::::MAX_CELL_BIT_WIDTH as u32) as u64; + let single_wit_size = UInt::::NUM_CELLS; + let wit_end_idx = if overflow { + 4 * single_wit_size + } else { + 4 * single_wit_size - 1 + }; + let a = &witness_values[0..single_wit_size]; + let b = &witness_values[single_wit_size..2 * single_wit_size]; + let carries = &witness_values[3 * single_wit_size..wit_end_idx]; // limbs cal. - let t0 = a[0] * b[0] - c_carries[0] * POW_OF_C; - let t1 = a[0] * b[1] + a[1] * b[0] - c_carries[1] * POW_OF_C + c_carries[0]; - let t2 = - a[0] * b[2] + a[1] * b[1] + a[2] * b[0] - c_carries[2] * POW_OF_C + c_carries[1]; - let t3 = a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0] - - c_carries[3] * POW_OF_C - + c_carries[2]; + let mut result = vec![0u64; single_wit_size]; + a.iter().enumerate().for_each(|(i, a_limb)| { + b.iter().enumerate().for_each(|(j, b_limb)| { + let idx = i + j; + if idx < single_wit_size { + result[idx] += a_limb * b_limb; + } + }); + }); + + // take care carries + result.iter_mut().enumerate().for_each(|(i, ret)| { + if i != 0 { + *ret += carries[i - 1]; + } + if !overflow && carries.get(i).is_some() { + *ret -= carries[i] * pow_of_c; + } + }); // verify - let c_expr = uint_c.expr(); - let w: Vec = witness_values.iter().map(|&a| a.into()).collect_vec(); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[0]), E::from(t0)); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[1]), E::from(t1)); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[2]), E::from(t2)); - assert_eq!(eval_by_expr(&w, &challenges, &c_expr[3]), E::from(t3)); + let wit: Vec = witness_values.iter().map(|&w| w.into()).collect_vec(); + uint_c.expr().iter().zip(result).for_each(|(c, ret)| { + assert_eq!(eval_by_expr(&wit, &challenges, c), E::from(ret)); + }); + // overflow - assert_eq!( - eval_by_expr( - &w, - &challenges, - &uint_c.carries.unwrap().last().unwrap().expr() - ), - if overflow { E::ONE } else { E::ZERO } - ); + if overflow { + let overflow = uint_c.carries.unwrap().last().unwrap().expr(); + assert_eq!(eval_by_expr(&wit, &challenges, &overflow), E::ONE); + } else { + // non-overflow case, the len of carries should be (NUM_CELLS - 1) + assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) + } } } @@ -842,18 +822,16 @@ mod tests { .map(|&a| a.into()) .collect_vec(); - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cs = ConstraintSystem::new(|| "test_add_mul"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); - let mut uint_c = uint_a - .add(|| "uint_c", &mut circuit_builder, &uint_b, false) - .unwrap(); - let mut uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut circuit_builder).unwrap(); + let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut cb).unwrap(); + let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); + let mut uint_c = uint_a.add(|| "uint_c", &mut cb, &uint_b, false).unwrap(); + let mut uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); let uint_e = uint_c - .mul(|| "uint_e", &mut circuit_builder, &mut uint_d, false) + .mul(|| "uint_e", &mut cb, &mut uint_d, false) .unwrap(); uint_e.expr().iter().enumerate().for_each(|(i, ret)| { @@ -909,22 +887,18 @@ mod tests { .map(|&a| a.into()) .collect_vec(); - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cs = ConstraintSystem::new(|| "test_add_mul2"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); - let mut uint_c = uint_a - .add(|| "uint_c", &mut circuit_builder, &uint_b, false) - .unwrap(); - let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut circuit_builder).unwrap(); - let uint_e = UInt::<64, 16, E>::new(|| "uint_e", &mut circuit_builder).unwrap(); - let mut uint_f = uint_d - .add(|| "uint_f", &mut circuit_builder, &uint_e, false) - .unwrap(); + let uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut cb).unwrap(); + let uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); + let mut uint_c = uint_a.add(|| "uint_c", &mut cb, &uint_b, false).unwrap(); + let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); + let uint_e = UInt::<64, 16, E>::new(|| "uint_e", &mut cb).unwrap(); + let mut uint_f = uint_d.add(|| "uint_f", &mut cb, &uint_e, false).unwrap(); let uint_g = uint_c - .mul(|| "unit_g", &mut circuit_builder, &mut uint_f, false) + .mul(|| "unit_g", &mut cb, &mut uint_f, false) .unwrap(); uint_g.expr().iter().enumerate().for_each(|(i, ret)| { @@ -961,19 +935,17 @@ mod tests { .map(|&a| a.into()) .collect_vec(); - let mut cs = ConstraintSystem::new(|| "test"); - let mut circuit_builder = CircuitBuilder::::new(&mut cs); + let mut cs = ConstraintSystem::new(|| "test_mul_add"); + let mut cb = CircuitBuilder::::new(&mut cs); let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); - let mut uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut circuit_builder).unwrap(); - let mut uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut circuit_builder).unwrap(); + let mut uint_a = UInt::<64, 16, E>::new(|| "uint_a", &mut cb).unwrap(); + let mut uint_b = UInt::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); let uint_c = uint_a - .mul(|| "uint_c", &mut circuit_builder, &mut uint_b, false) - .unwrap(); - let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut circuit_builder).unwrap(); - let uint_e = uint_c - .add(|| "uint_e", &mut circuit_builder, &uint_d, false) + .mul(|| "uint_c", &mut cb, &mut uint_b, false) .unwrap(); + let uint_d = UInt::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); + let uint_e = uint_c.add(|| "uint_e", &mut cb, &uint_d, false).unwrap(); uint_e.expr().iter().enumerate().for_each(|(i, ret)| { // limbs check