diff --git a/src/utils/sqrt.rs b/src/utils/sqrt.rs index 5c897ef..0fcfc9f 100644 --- a/src/utils/sqrt.rs +++ b/src/utils/sqrt.rs @@ -1,6 +1,14 @@ use num_bigint::{BigInt, ToBigInt}; use num_traits::{ToPrimitive, Zero}; +/// Computes floor(sqrt(value)) +/// +/// # Arguments +/// +/// * `value`: The value for which to compute the square root, rounded down +/// +/// returns: BigInt +/// fn sqrt(value: &BigInt) -> BigInt { assert!(*value >= Zero::zero(), "NEGATIVE"); @@ -25,3 +33,36 @@ fn sqrt(value: &BigInt) -> BigInt { z } + +#[cfg(test)] +mod tests { + use super::*; + use num_traits::Num; + + #[test] + fn test_sqrt_0_1000() { + for i in 0..1000 { + let sqrt_i = sqrt(&BigInt::from(i)); + assert_eq!(sqrt_i, BigInt::from((i as f64).sqrt().floor() as i64)); + } + } + + #[test] + fn test_sqrt_2_powers() { + for i in 0..256 { + let root = BigInt::from(2).pow(i as u32); + let root_squared = &root * &root; + assert_eq!(sqrt(&root_squared), root); + } + } + + #[test] + fn test_sqrt_max_uint256() { + let max_uint256_string = + "115792089237316195423570985008687907853269984665640564039457584007913129639935"; + let max_uint256 = BigInt::from_str_radix(max_uint256_string, 10).unwrap(); + let expected_sqrt = + BigInt::from_str_radix("340282366920938463463374607431768211455", 10).unwrap(); + assert_eq!(sqrt(&max_uint256), expected_sqrt); + } +}