diff --git a/ff/src/biginteger/mod.rs b/ff/src/biginteger/mod.rs index e0d025965..46a236fcb 100644 --- a/ff/src/biginteger/mod.rs +++ b/ff/src/biginteger/mod.rs @@ -444,10 +444,9 @@ impl BigInteger for BigInt { } #[inline] - fn mul_low(&mut self, other: &Self) { + fn mul_low(&self, other: &Self) -> Self { if self.is_zero() || other.is_zero() { - *self = Self::zero(); - return; + return Self::zero(); } let mut res = Self::zero(); @@ -460,7 +459,7 @@ impl BigInteger for BigInt { carry = 0; } - *self = res + res } #[inline] @@ -1099,15 +1098,13 @@ pub trait BigInteger: /// // Basic /// let mut a = B::from(42u64); /// let b = B::from(3u64); - /// a.mul_low(&b); - /// assert_eq!(a, B::from(126u64)); + /// assert_eq!(a.mul_low(&b), B::from(126u64)); /// /// // Edge-Case /// let mut zero = B::from(0u64); - /// zero.mul_low(&B::from(5u64)); - /// assert_eq!(zero, B::from(0u64)); + /// assert_eq!(zero.mul_low(&B::from(5u64)), B::from(0u64)); /// ``` - fn mul_low(&mut self, other: &Self); + fn mul_low(&self, other: &Self) -> Self; /// Multiplies this [`BigInteger`] by another `BigInteger`, returning the high bits of the result. /// diff --git a/ff/src/biginteger/tests.rs b/ff/src/biginteger/tests.rs index 83767aaf5..dad4730c1 100644 --- a/ff/src/biginteger/tests.rs +++ b/ff/src/biginteger/tests.rs @@ -51,29 +51,21 @@ fn biginteger_arithmetic_test(a: B, b: B, zero: B, max: B) { assert_eq!(a_mul2, a_plus_a); // a * 1 = a - let mut a_mul = a; - a_mul.mul_low(&B::from(1u64)); - assert_eq!(a_mul, a); + assert_eq!(a.mul_low(&B::from(1u64)), a); // a * 2 = a - a_mul.mul_low(&B::from(2u64)); - assert_eq!(a_mul, a_plus_a); + assert_eq!(a.mul_low(&B::from(2u64)), a_plus_a); - // a * 2 * b = b * 2 * a - a_mul.mul_low(&b); - let mut b_mul = b; - b_mul.mul_low(&B::from(2u64)); - b_mul.mul_low(&a); - assert_eq!(a_mul, b_mul); + // a * b = b * a + assert_eq!(a.mul_low(&b), b.mul_low(&a)); // a * 2 * b * 0 = 0 - a_mul.mul_low(&zero); - assert!(a_mul.is_zero()); + assert!(a.mul_low(&zero).is_zero()); // a * 2 * ... * 2 = a * 2^n let mut a_mul_n = a; for _ in 0..20 { - a_mul_n.mul_low(&B::from(2u64)); + a_mul_n = a_mul_n.mul_low(&B::from(2u64)); } assert_eq!(a_mul_n, a << 20);