Skip to content

Commit

Permalink
Merge #56
Browse files Browse the repository at this point in the history
56: Implement Roots for BigInt and BigUint r=cuviper a=mancabizjak

Supersedes #51 .

Since there is now a `Roots` trait with `sqrt`, `cbrt` and `nth_root` methods in the `num-integer` crate, this PR implements it for `BigInt` and `BigUint` types. I also added inherent methods on both types to allow the users access to all these functions without having to import `Roots`.

PS: `nth_root` currently  uses `num_traits::pow`. Should we perhaps wait for #54 to get merged, and then replace the call to use the new `pow::Pow` implementation on `BigUint`?

Co-authored-by: Manca Bizjak <[email protected]>
  • Loading branch information
bors[bot] and mancabizjak committed Jul 19, 2018
2 parents 86e019b + 1d45ca9 commit c504fa8
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ name = "shootout-pidigits"
[dependencies]

[dependencies.num-integer]
version = "0.1.38"
version = "0.1.39"
default-features = false

[dependencies.num-traits]
Expand Down
25 changes: 25 additions & 0 deletions benches/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
extern crate test;
extern crate num_bigint;
extern crate num_traits;
extern crate num_integer;
extern crate rand;

use std::mem::replace;
Expand Down Expand Up @@ -342,3 +343,27 @@ fn modpow_even(b: &mut Bencher) {

b.iter(|| base.modpow(&e, &m));
}

#[bench]
fn roots_sqrt(b: &mut Bencher) {
let mut rng = get_rng();
let x = rng.gen_biguint(2048);

b.iter(|| x.sqrt());
}

#[bench]
fn roots_cbrt(b: &mut Bencher) {
let mut rng = get_rng();
let x = rng.gen_biguint(2048);

b.iter(|| x.cbrt());
}

#[bench]
fn roots_nth_100(b: &mut Bencher) {
let mut rng = get_rng();
let x = rng.gen_biguint(2048);

b.iter(|| x.nth_root(100));
}
39 changes: 38 additions & 1 deletion src/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::iter::{Product, Sum};
#[cfg(feature = "serde")]
use serde;

use integer::Integer;
use integer::{Integer, Roots};
use traits::{ToPrimitive, FromPrimitive, Num, CheckedAdd, CheckedSub,
CheckedMul, CheckedDiv, Signed, Zero, One};

Expand Down Expand Up @@ -1802,6 +1802,25 @@ impl Integer for BigInt {
}
}

impl Roots for BigInt {
fn nth_root(&self, n: u32) -> Self {
assert!(!(self.is_negative() && n.is_even()),
"root of degree {} is imaginary", n);

BigInt::from_biguint(self.sign, self.data.nth_root(n))
}

fn sqrt(&self) -> Self {
assert!(!self.is_negative(), "square root is imaginary");

BigInt::from_biguint(self.sign, self.data.sqrt())
}

fn cbrt(&self) -> Self {
BigInt::from_biguint(self.sign, self.data.cbrt())
}
}

impl ToPrimitive for BigInt {
#[inline]
fn to_i64(&self) -> Option<i64> {
Expand Down Expand Up @@ -2538,6 +2557,24 @@ impl BigInt {
};
BigInt::from_biguint(sign, mag)
}

/// Returns the truncated principal square root of `self` --
/// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt).
pub fn sqrt(&self) -> Self {
Roots::sqrt(self)
}

/// Returns the truncated principal cube root of `self` --
/// see [Roots::cbrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.cbrt).
pub fn cbrt(&self) -> Self {
Roots::cbrt(self)
}

/// Returns the truncated principal `n`th root of `self` --
/// See [Roots::nth_root](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#tymethod.nth_root).
pub fn nth_root(&self, n: u32) -> Self {
Roots::nth_root(self, n)
}
}

impl_sum_iter_type!(BigInt);
Expand Down
110 changes: 108 additions & 2 deletions src/biguint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ use std::ascii::AsciiExt;
#[cfg(feature = "serde")]
use serde;

use integer::Integer;
use integer::{Integer, Roots};
use traits::{ToPrimitive, FromPrimitive, Float, Num, Unsigned, CheckedAdd, CheckedSub, CheckedMul,
CheckedDiv, Zero, One};
CheckedDiv, Zero, One, pow};

use big_digit::{self, BigDigit, DoubleBigDigit};

Expand Down Expand Up @@ -1026,6 +1026,94 @@ impl Integer for BigUint {
}
}

impl Roots for BigUint {
// nth_root, sqrt and cbrt use Newton's method to compute
// principal root of a given degree for a given integer.

// Reference:
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14
fn nth_root(&self, n: u32) -> Self {
assert!(n > 0, "root degree n must be at least 1");

if self.is_zero() || self.is_one() {
return self.clone()
}

match n { // Optimize for small n
1 => return self.clone(),
2 => return self.sqrt(),
3 => return self.cbrt(),
_ => (),
}

let n = n as usize;
let n_min_1 = n - 1;

let guess = BigUint::one() << (self.bits()/n + 1);

let mut u = guess;
let mut s: BigUint;

loop {
s = u;
let q = self / pow(s.clone(), n_min_1);
let t: BigUint = n_min_1 * &s + q;

u = t / n;

if u >= s { break; }
}

s
}

// Reference:
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
fn sqrt(&self) -> Self {
if self.is_zero() || self.is_one() {
return self.clone()
}

let guess = BigUint::one() << (self.bits()/2 + 1);

let mut u = guess;
let mut s: BigUint;

loop {
s = u;
let q = self / &s;
let t: BigUint = &s + q;
u = t >> 1;

if u >= s { break; }
}

s
}

fn cbrt(&self) -> Self {
if self.is_zero() || self.is_one() {
return self.clone()
}

let guess = BigUint::one() << (self.bits()/3 + 1);

let mut u = guess;
let mut s: BigUint;

loop {
s = u;
let q = self / (&s * &s);
let t: BigUint = (&s << 1) + q;
u = t / 3u32;

if u >= s { break; }
}

s
}
}

fn high_bits_to_u64(v: &BigUint) -> u64 {
match v.data.len() {
0 => 0,
Expand Down Expand Up @@ -1749,6 +1837,24 @@ impl BigUint {
}
acc
}

/// Returns the truncated principal square root of `self` --
/// see [Roots::sqrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.sqrt)
pub fn sqrt(&self) -> Self {
Roots::sqrt(self)
}

/// Returns the truncated principal cube root of `self` --
/// see [Roots::cbrt](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#method.cbrt).
pub fn cbrt(&self) -> Self {
Roots::cbrt(self)
}

/// Returns the truncated principal `n`th root of `self` --
/// see [Roots::nth_root](https://docs.rs/num-integer/0.1/num_integer/trait.Roots.html#tymethod.nth_root).
pub fn nth_root(&self, n: u32) -> Self {
Roots::nth_root(self, n)
}
}

/// Returns the number of least-significant bits that are zero,
Expand Down
104 changes: 104 additions & 0 deletions tests/roots.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
extern crate num_bigint;
extern crate num_integer;
extern crate num_traits;

mod biguint {
use num_bigint::BigUint;
use num_traits::pow;
use std::str::FromStr;

fn check(x: u64, n: u32) {
let big_x = BigUint::from(x);
let res = big_x.nth_root(n);

if n == 2 {
assert_eq!(&res, &big_x.sqrt())
} else if n == 3 {
assert_eq!(&res, &big_x.cbrt())
}

assert!(pow(res.clone(), n as usize) <= big_x);
assert!(pow(res.clone() + 1u32, n as usize) > big_x);
}

#[test]
fn test_sqrt() {
check(99, 2);
check(100, 2);
check(120, 2);
}

#[test]
fn test_cbrt() {
check(8, 3);
check(26, 3);
}

#[test]
fn test_nth_root() {
check(0, 1);
check(10, 1);
check(100, 4);
}

#[test]
#[should_panic]
fn test_nth_root_n_is_zero() {
check(4, 0);
}

#[test]
fn test_nth_root_big() {
let x = BigUint::from_str("123_456_789").unwrap();
let expected = BigUint::from(6u32);

assert_eq!(x.nth_root(10), expected);
}
}

mod bigint {
use num_bigint::BigInt;
use num_traits::{Signed, pow};

fn check(x: i64, n: u32) {
let big_x = BigInt::from(x);
let res = big_x.nth_root(n);

if n == 2 {
assert_eq!(&res, &big_x.sqrt())
} else if n == 3 {
assert_eq!(&res, &big_x.cbrt())
}

if big_x.is_negative() {
assert!(pow(res.clone() - 1u32, n as usize) < big_x);
assert!(pow(res.clone(), n as usize) >= big_x);
} else {
assert!(pow(res.clone(), n as usize) <= big_x);
assert!(pow(res.clone() + 1u32, n as usize) > big_x);
}
}

#[test]
fn test_nth_root() {
check(-100, 3);
}

#[test]
#[should_panic]
fn test_nth_root_x_neg_n_even() {
check(-100, 4);
}

#[test]
#[should_panic]
fn test_sqrt_x_neg() {
check(-4, 2);
}

#[test]
fn test_cbrt() {
check(8, 3);
check(-8, 3);
}
}

0 comments on commit c504fa8

Please sign in to comment.