Skip to content

Commit

Permalink
Move numeric tests into the numeric test submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeMathWalker authored and jturner314 committed Mar 26, 2019
1 parent 6e37c46 commit 64b3da7
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 169 deletions.
167 changes: 0 additions & 167 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -925,173 +925,6 @@ fn assign()
assert_eq!(a, arr2(&[[0, 0], [3, 4]]));
}

#[test]
fn sum_mean()
{
let a = arr2(&[[1., 2.], [3., 4.]]);
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.])));
assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5])));
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5]));
assert_eq!(a.sum(), 10.);
}

#[test]
fn sum_mean_empty() {
assert_eq!(Array3::<f32>::ones((2, 0, 3)).sum(), 0.);
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
assert_eq!(
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
Array::zeros((2, 3)),
);
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
assert_eq!(a, None);
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
assert_eq!(a, None);
}

#[test]
fn var_axis() {
let a = array![
[
[-9.76, -0.38, 1.59, 6.23],
[-8.57, -9.27, 5.76, 6.01],
[-9.54, 5.09, 3.21, 6.56],
],
[
[ 8.23, -9.63, 3.76, -3.48],
[-5.46, 5.86, -2.81, 1.35],
[-1.08, 4.66, 8.34, -0.73],
],
];
assert!(a.var_axis(Axis(0), 1.5).all_close(
&aview2(&[
[3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01],
[9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01],
[7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01]
]),
1e-4,
));
assert!(a.var_axis(Axis(1), 1.7).all_close(
&aview2(&[
[0.61676923, 80.81092308, 6.79892308, 0.11789744],
[75.19912821, 114.25235897, 48.32405128, 9.03020513],
]),
1e-8,
));
assert!(a.var_axis(Axis(2), 2.3).all_close(
&aview2(&[
[ 79.64552941, 129.09663235, 95.98929412],
[109.64952941, 43.28758824, 36.27439706],
]),
1e-8,
));

let b = array![[1.1, 2.3, 4.7]];
assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12));

let c = array![[], []];
assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[]));

let d = array![1.1, 2.7, 3.5, 4.9];
assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12));
}

#[test]
fn std_axis() {
let a = array![
[
[ 0.22935481, 0.08030619, 0.60827517, 0.73684379],
[ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ],
[ 0.44485313, 0.63316367, 0.11005111, 0.08656246]
],
[
[ 0.28924665, 0.44082454, 0.59837736, 0.41014531],
[ 0.08382316, 0.43259439, 0.1428889 , 0.44830176],
[ 0.51529756, 0.70111616, 0.20799415, 0.91851457]
],
];
assert!(a.std_axis(Axis(0), 1.5).all_close(
&aview2(&[
[ 0.05989184, 0.36051836, 0.00989781, 0.32669847],
[ 0.81957535, 0.39599997, 0.49731472, 0.17084346],
[ 0.07044443, 0.06795249, 0.09794304, 0.83195211],
]),
1e-4,
));
assert!(a.std_axis(Axis(1), 1.7).all_close(
&aview2(&[
[ 0.42698655, 0.48139215, 0.36874991, 0.41458724],
[ 0.26769097, 0.18941435, 0.30555015, 0.35118674],
]),
1e-8,
));
assert!(a.std_axis(Axis(2), 2.3).all_close(
&aview2(&[
[ 0.41117907, 0.37130425, 0.35332388],
[ 0.16905862, 0.25304841, 0.39978276],
]),
1e-8,
));

let b = array![[100000., 1., 0.01]];
assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
assert!(
b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6),
);

let c = array![[], []];
assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[]));
}

#[test]
#[should_panic]
fn var_axis_negative_ddof() {
let a = array![1., 2., 3.];
a.var_axis(Axis(0), -1.);
}

#[test]
#[should_panic]
fn var_axis_too_large_ddof() {
let a = array![1., 2., 3.];
a.var_axis(Axis(0), 4.);
}

#[test]
fn var_axis_nan_ddof() {
let a = Array2::<f64>::zeros((2, 3));
let v = a.var_axis(Axis(1), ::std::f64::NAN);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn var_axis_empty_axis() {
let a = Array2::<f64>::zeros((2, 0));
let v = a.var_axis(Axis(1), 0.);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
#[should_panic]
fn std_axis_bad_dof() {
let a = array![1., 2., 3.];
a.std_axis(Axis(0), 4.);
}

#[test]
fn std_axis_empty_axis() {
let a = Array2::<f64>::zeros((2, 0));
let v = a.std_axis(Axis(1), 0.);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn iter_size_hint()
{
Expand Down
172 changes: 170 additions & 2 deletions tests/numeric.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern crate approx;
use std::f64;
use ndarray::{Array1, array};
use ndarray::{array, Axis, aview1, aview2, aview0, arr0, arr1, arr2, Array, Array1, Array2, Array3};
use approx::abs_diff_eq;

#[test]
Expand Down Expand Up @@ -32,4 +32,172 @@ fn test_mean_with_array_of_floats() {
// Computed using NumPy
let expected_mean = 0.5475494059146699;
abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = f64::EPSILON);
}
}

#[test]
fn sum_mean()
{
let a = arr2(&[[1., 2.], [3., 4.]]);
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.])));
assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5])));
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5]));
assert_eq!(a.sum(), 10.);
}

#[test]
fn sum_mean_empty() {
assert_eq!(Array3::<f32>::ones((2, 0, 3)).sum(), 0.);
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
assert_eq!(
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
Array::zeros((2, 3)),
);
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
assert_eq!(a, None);
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
assert_eq!(a, None);
}

#[test]
fn var_axis() {
let a = array![
[
[-9.76, -0.38, 1.59, 6.23],
[-8.57, -9.27, 5.76, 6.01],
[-9.54, 5.09, 3.21, 6.56],
],
[
[ 8.23, -9.63, 3.76, -3.48],
[-5.46, 5.86, -2.81, 1.35],
[-1.08, 4.66, 8.34, -0.73],
],
];
assert!(a.var_axis(Axis(0), 1.5).all_close(
&aview2(&[
[3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01],
[9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01],
[7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01]
]),
1e-4,
));
assert!(a.var_axis(Axis(1), 1.7).all_close(
&aview2(&[
[0.61676923, 80.81092308, 6.79892308, 0.11789744],
[75.19912821, 114.25235897, 48.32405128, 9.03020513],
]),
1e-8,
));
assert!(a.var_axis(Axis(2), 2.3).all_close(
&aview2(&[
[ 79.64552941, 129.09663235, 95.98929412],
[109.64952941, 43.28758824, 36.27439706],
]),
1e-8,
));

let b = array![[1.1, 2.3, 4.7]];
assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12));

let c = array![[], []];
assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[]));

let d = array![1.1, 2.7, 3.5, 4.9];
assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12));
}

#[test]
fn std_axis() {
let a = array![
[
[ 0.22935481, 0.08030619, 0.60827517, 0.73684379],
[ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ],
[ 0.44485313, 0.63316367, 0.11005111, 0.08656246]
],
[
[ 0.28924665, 0.44082454, 0.59837736, 0.41014531],
[ 0.08382316, 0.43259439, 0.1428889 , 0.44830176],
[ 0.51529756, 0.70111616, 0.20799415, 0.91851457]
],
];
assert!(a.std_axis(Axis(0), 1.5).all_close(
&aview2(&[
[ 0.05989184, 0.36051836, 0.00989781, 0.32669847],
[ 0.81957535, 0.39599997, 0.49731472, 0.17084346],
[ 0.07044443, 0.06795249, 0.09794304, 0.83195211],
]),
1e-4,
));
assert!(a.std_axis(Axis(1), 1.7).all_close(
&aview2(&[
[ 0.42698655, 0.48139215, 0.36874991, 0.41458724],
[ 0.26769097, 0.18941435, 0.30555015, 0.35118674],
]),
1e-8,
));
assert!(a.std_axis(Axis(2), 2.3).all_close(
&aview2(&[
[ 0.41117907, 0.37130425, 0.35332388],
[ 0.16905862, 0.25304841, 0.39978276],
]),
1e-8,
));

let b = array![[100000., 1., 0.01]];
assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
assert!(
b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6),
);

let c = array![[], []];
assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[]));
}

#[test]
#[should_panic]
fn var_axis_negative_ddof() {
let a = array![1., 2., 3.];
a.var_axis(Axis(0), -1.);
}

#[test]
#[should_panic]
fn var_axis_too_large_ddof() {
let a = array![1., 2., 3.];
a.var_axis(Axis(0), 4.);
}

#[test]
fn var_axis_nan_ddof() {
let a = Array2::<f64>::zeros((2, 3));
let v = a.var_axis(Axis(1), ::std::f64::NAN);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn var_axis_empty_axis() {
let a = Array2::<f64>::zeros((2, 0));
let v = a.var_axis(Axis(1), 0.);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
#[should_panic]
fn std_axis_bad_dof() {
let a = array![1., 2., 3.];
a.std_axis(Axis(0), 4.);
}

#[test]
fn std_axis_empty_axis() {
let a = Array2::<f64>::zeros((2, 0));
let v = a.std_axis(Axis(1), 0.);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

0 comments on commit 64b3da7

Please sign in to comment.