diff --git a/src/fns/unconstrained_helpers.nr b/src/fns/unconstrained_helpers.nr index 0afcfb40..28eee2f3 100644 --- a/src/fns/unconstrained_helpers.nr +++ b/src/fns/unconstrained_helpers.nr @@ -39,12 +39,11 @@ pub(crate) unconstrained fn __validate_gt_remainder( let mut b_u60: U60Repr = U60Repr::from(rhs); let underflow = b_u60.gte(a_u60); - b_u60 += U60Repr::one(); assert(underflow == false, "BigNum::validate_gt check fails"); let mut result_u60: U60Repr = U60Repr { limbs: [0; 2 * N] }; let mut carry_in: u64 = 0; - let mut borrow_in: u64 = 0; + let mut borrow_in: u64 = 1; let mut borrow_flags: [bool; N] = [false; N]; let mut carry_flags: [bool; N] = [false; N]; for i in 0..2 * N { diff --git a/src/fns/unconstrained_ops.nr b/src/fns/unconstrained_ops.nr index 68e12df0..86711fe7 100644 --- a/src/fns/unconstrained_ops.nr +++ b/src/fns/unconstrained_ops.nr @@ -171,34 +171,40 @@ pub(crate) unconstrained fn __udiv_mod( let mut divisor_u60: U60Repr = U60Repr::from(divisor); let b = divisor_u60; - let mut bit_difference = remainder_u60.get_msb() - divisor_u60.get_msb(); + let numerator_msb = remainder_u60.get_msb(); + let divisor_msb = divisor_u60.get_msb(); + if divisor_msb > numerator_msb { + ([0; N], numerator) + } else { + let mut bit_difference = numerator_msb - divisor_msb; - let mut accumulator_u60: U60Repr = U60Repr::one(); - divisor_u60 = divisor_u60.shl(bit_difference); - accumulator_u60 = accumulator_u60.shl(bit_difference); + let mut accumulator_u60: U60Repr = U60Repr::one(); + divisor_u60 = divisor_u60.shl(bit_difference); + accumulator_u60 = accumulator_u60.shl(bit_difference); - if (divisor_u60.gte(remainder_u60 + U60Repr::one())) { - divisor_u60.shr1(); - accumulator_u60.shr1(); - } - for _ in 0..(N * 120) { - if (remainder_u60.gte(b) == false) { - break; + if (divisor_u60.gte(remainder_u60 + U60Repr::one())) { + divisor_u60.shr1(); + accumulator_u60.shr1(); } + for _ in 0..(N * 120) { + if (remainder_u60.gte(b) == false) { + break; + } - // we've shunted 'divisor' up to have the same bit length as our remainder. - // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b - if (remainder_u60.gte(divisor_u60)) { - remainder_u60 -= divisor_u60; - // we can use OR here instead of +, as - // accumulator is always a nice power of two - quotient_u60 = quotient_u60 + accumulator_u60; + // we've shunted 'divisor' up to have the same bit length as our remainder. + // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b + if (remainder_u60.gte(divisor_u60)) { + remainder_u60 -= divisor_u60; + // we can use OR here instead of +, as + // accumulator is always a nice power of two + quotient_u60 = quotient_u60 + accumulator_u60; + } + divisor_u60.shr1(); // >>= 1; + accumulator_u60.shr1(); // >>= 1; } - divisor_u60.shr1(); // >>= 1; - accumulator_u60.shr1(); // >>= 1; - } - (U60Repr::into(quotient_u60), U60Repr::into(remainder_u60)) + (U60Repr::into(quotient_u60), U60Repr::into(remainder_u60)) + } } pub(crate) unconstrained fn __invmod( diff --git a/src/tests/bignum_test.nr b/src/tests/bignum_test.nr index ffc8f524..c16c136d 100644 --- a/src/tests/bignum_test.nr +++ b/src/tests/bignum_test.nr @@ -632,6 +632,21 @@ fn test_udiv_mod_U256() { assert(product == a); } +#[test] +fn test_1_udiv_mod_2() { + let _0: U256 = BigNum::new(); + let _1: U256 = BigNum::one(); + assert(_1.udiv_mod(_1 + _1) == (_0, _1)); +} + +#[test] +fn test_20_udiv_mod_11() { + let _1: U256 = BigNum::one(); + let _2_POW_120: U256 = BigNum::from_slice([0, 1, 0]); + let _2_POW_121: U256 = BigNum::from_slice([0, 2, 0]); + assert(_2_POW_121.udiv_mod(_2_POW_120 + _1) == (_1, _2_POW_120 - _1)); +} + // // N.B. witness generation times make these tests take ~15 minutes each! Uncomment at your peril // #[test] // fn test_div_2048() {