diff --git a/ff/src/biginteger/arithmetic.rs b/ff/src/biginteger/arithmetic.rs index 227c75593..bf6cefa0c 100644 --- a/ff/src/biginteger/arithmetic.rs +++ b/ff/src/biginteger/arithmetic.rs @@ -1,4 +1,3 @@ -use ark_std::Zero; use ark_std::{vec, vec::*}; macro_rules! adc { @@ -153,52 +152,53 @@ pub fn mac_with_carry(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 { /// Compute the NAF (non-adjacent form) of num pub fn find_naf(num: &[u64]) -> Vec { - let is_zero = |num: &[u64]| num.iter().all(Zero::is_zero); + let mut num = num.to_vec(); + let mut res = vec![]; + + // Helper functions for arithmetic operations + // Check if the number is non-zero + let is_non_zero = |num: &[u64]| num.iter().any(|&x| x != 0); + // Check if the number is odd let is_odd = |num: &[u64]| num[0] & 1 == 1; + // Subtract a value `z` without borrow propagation let sub_noborrow = |num: &mut [u64], z: u64| { - let mut other = vec![0u64; num.len()]; - other[0] = z; - let mut borrow = 0; - - for (a, b) in num.iter_mut().zip(other) { - borrow = sbb(a, b, borrow); - } + num.iter_mut() + .zip(ark_std::iter::once(z).chain(ark_std::iter::repeat(0))) + .fold(0, |borrow, (a, b)| sbb(a, b, borrow)); }; + // Add a value `z` without carry propagation let add_nocarry = |num: &mut [u64], z: u64| { - let mut other = vec![0u64; num.len()]; - other[0] = z; - let mut carry = 0; - - for (a, b) in num.iter_mut().zip(other) { - carry = adc(a, b, carry); - } + num.iter_mut() + .zip(ark_std::iter::once(z).chain(ark_std::iter::repeat(0))) + .fold(0, |carry, (a, b)| adc(a, b, carry)); }; + // Perform an in-place division of the number by 2 let div2 = |num: &mut [u64]| { - let mut t = 0; - for i in num.iter_mut().rev() { - let t2 = *i << 63; - *i >>= 1; - *i |= t; - t = t2; - } + num.iter_mut().rev().fold(0, |carry, x| { + let next_carry = *x << 63; + *x = (*x >> 1) | carry; + next_carry + }); }; - let mut num = num.to_vec(); - let mut res = vec![]; - - while !is_zero(&num) { - let z: i8; - if is_odd(&num) { - z = 2 - (num[0] % 4) as i8; + // Main loop for NAF computation + while is_non_zero(&num) { + // Determine the current digit of the NAF representation + let z = if is_odd(&num) { + let z = 2 - (num[0] % 4) as i8; if z >= 0 { - sub_noborrow(&mut num, z as u64) + sub_noborrow(&mut num, z as u64); } else { - add_nocarry(&mut num, (-z) as u64) + add_nocarry(&mut num, (-z) as u64); } + z } else { - z = 0; - } + 0 + }; + + // Append the digit to the result res.push(z); + // Divide the number by 2 for the next iteration div2(&mut num); } @@ -454,4 +454,47 @@ mod tests { assert_eq!(test, test_expected); } } + + #[test] + fn test_find_naf_zero() { + // Test for zero input + let naf = find_naf(&[0]); + assert!(naf.is_empty()); + } + + #[test] + fn test_find_naf_single_digit() { + // Test for small numbers + assert_eq!(find_naf(&[1]), vec![1]); + assert_eq!(find_naf(&[2]), vec![0, 1]); + assert_eq!(find_naf(&[3]), vec![-1, 0, 1]); + assert_eq!(find_naf(&[4]), vec![0, 0, 1]); + } + + #[test] + fn test_find_naf_large_number() { + // Test for a larger number + assert_eq!(find_naf(&[13]), vec![1, 0, -1, 0, 1]); + } + + #[test] + fn test_find_naf_multiple_blocks() { + // Test multi-block number (simulate large numbers split across blocks) + let num = [0, 1]; + assert_eq!( + find_naf(&num), + vec![ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1 + ] + ); + } + + #[test] + fn test_find_naf_edge_cases() { + // Test edge cases + let naf = find_naf(&[u64::MAX]); + assert!(naf.len() > 0); + } }