Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mean #580

Merged
merged 3 commits into from
Mar 26, 2019
Merged

Mean #580

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ serde = { version = "1.0", optional = true }
defmac = "0.2"
quickcheck = { version = "0.7.2", default-features = false }
rawpointer = "0.1"
approx = "0.3"
LukeMathWalker marked this conversation as resolved.
Show resolved Hide resolved

[features]
# Enable blas usage
Expand Down
4 changes: 2 additions & 2 deletions examples/column_standardize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ fn main() {
[ 2., 2., 2.]];

println!("{:8.4}", data);
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)));
println!("{:8.4} (Mean axis=0)", data.mean_axis(Axis(0)).unwrap());

data -= &data.mean_axis(Axis(0));
data -= &data.mean_axis(Axis(0)).unwrap();
println!("{:8.4}", data);

data /= &std(&data, Axis(0));
Expand Down
52 changes: 43 additions & 9 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,33 @@ impl<A, S, D> ArrayBase<S, D>
sum
}

/// Returns the [arithmetic mean] x̅ of all elements in the array:
///
/// ```text
/// 1 n
/// x̅ = ― ∑ xᵢ
/// n i=1
/// ```
///
/// If the array is empty, `None` is returned.
///
/// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
///
/// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
pub fn mean(&self) -> Option<A>
where
A: Clone + FromPrimitive + Add<Output=A> + Div<Output=A> + Zero
{
let n_elements = self.len();
if n_elements == 0 {
None
} else {
let n_elements = A::from_usize(n_elements)
.expect("Converting number of elements to `A` must not fail.");
Some(self.sum() / n_elements)
}
}

/// Return the sum of all elements in the array.
///
/// *This method has been renamed to `.sum()` and will be deprecated in the
Expand Down Expand Up @@ -123,8 +150,9 @@ impl<A, S, D> ArrayBase<S, D>

/// Return mean along `axis`.
///
/// **Panics** if `axis` is out of bounds, if the length of the axis is
/// zero and division by zero panics for type `A`, or if `A::from_usize()`
/// Return `None` if the length of the axis is zero.
///
/// **Panics** if `axis` is out of bounds or if `A::from_usize()`
/// fails for the axis length.
///
/// ```
Expand All @@ -133,19 +161,25 @@ impl<A, S, D> ArrayBase<S, D>
/// let a = arr2(&[[1., 2., 3.],
/// [4., 5., 6.]]);
/// assert!(
/// a.mean_axis(Axis(0)) == aview1(&[2.5, 3.5, 4.5]) &&
/// a.mean_axis(Axis(1)) == aview1(&[2., 5.]) &&
/// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
/// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
///
/// a.mean_axis(Axis(0)).mean_axis(Axis(0)) == aview0(&3.5)
/// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
/// );
/// ```
pub fn mean_axis(&self, axis: Axis) -> Array<A, D::Smaller>
pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
where A: Clone + Zero + FromPrimitive + Add<Output=A> + Div<Output=A>,
D: RemoveAxis,
{
let n = A::from_usize(self.len_of(axis)).expect("Converting axis length to `A` must not fail.");
let sum = self.sum_axis(axis);
sum / &aview0(&n)
let axis_length = self.len_of(axis);
if axis_length == 0 {
None
} else {
let axis_length = A::from_usize(axis_length)
.expect("Converting axis length to `A` must not fail.");
let sum = self.sum_axis(axis);
Some(sum / &aview0(&axis_length))
}
}

/// Return variance along `axis`.
Expand Down
169 changes: 0 additions & 169 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -925,175 +925,6 @@ fn assign()
assert_eq!(a, arr2(&[[0, 0], [3, 4]]));
}

#[test]
fn sum_mean()
{
let a = arr2(&[[1., 2.], [3., 4.]]);
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
assert_eq!(a.mean_axis(Axis(0)), arr1(&[2., 3.]));
assert_eq!(a.mean_axis(Axis(1)), arr1(&[1.5, 3.5]));
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
assert_eq!(a.view().mean_axis(Axis(1)), aview1(&[1.5, 3.5]));
assert_eq!(a.sum(), 10.);
}

#[test]
fn sum_mean_empty() {
assert_eq!(Array3::<f32>::ones((2, 0, 3)).sum(), 0.);
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
assert_eq!(
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
Array::zeros((2, 3)),
);
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
assert_eq!(a.shape(), &[]);
assert!(a[()].is_nan());
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
assert_eq!(a.shape(), &[2, 3]);
a.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn var_axis() {
let a = array![
[
[-9.76, -0.38, 1.59, 6.23],
[-8.57, -9.27, 5.76, 6.01],
[-9.54, 5.09, 3.21, 6.56],
],
[
[ 8.23, -9.63, 3.76, -3.48],
[-5.46, 5.86, -2.81, 1.35],
[-1.08, 4.66, 8.34, -0.73],
],
];
assert!(a.var_axis(Axis(0), 1.5).all_close(
&aview2(&[
[3.236401e+02, 8.556250e+01, 4.708900e+00, 9.428410e+01],
[9.672100e+00, 2.289169e+02, 7.344490e+01, 2.171560e+01],
[7.157160e+01, 1.849000e-01, 2.631690e+01, 5.314410e+01]
]),
1e-4,
));
assert!(a.var_axis(Axis(1), 1.7).all_close(
&aview2(&[
[0.61676923, 80.81092308, 6.79892308, 0.11789744],
[75.19912821, 114.25235897, 48.32405128, 9.03020513],
]),
1e-8,
));
assert!(a.var_axis(Axis(2), 2.3).all_close(
&aview2(&[
[ 79.64552941, 129.09663235, 95.98929412],
[109.64952941, 43.28758824, 36.27439706],
]),
1e-8,
));

let b = array![[1.1, 2.3, 4.7]];
assert!(b.var_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
assert!(b.var_axis(Axis(1), 0.).all_close(&aview1(&[2.24]), 1e-12));

let c = array![[], []];
assert_eq!(c.var_axis(Axis(0), 0.), aview1(&[]));

let d = array![1.1, 2.7, 3.5, 4.9];
assert!(d.var_axis(Axis(0), 0.).all_close(&aview0(&1.8875), 1e-12));
}

#[test]
fn std_axis() {
let a = array![
[
[ 0.22935481, 0.08030619, 0.60827517, 0.73684379],
[ 0.90339851, 0.82859436, 0.64020362, 0.2774583 ],
[ 0.44485313, 0.63316367, 0.11005111, 0.08656246]
],
[
[ 0.28924665, 0.44082454, 0.59837736, 0.41014531],
[ 0.08382316, 0.43259439, 0.1428889 , 0.44830176],
[ 0.51529756, 0.70111616, 0.20799415, 0.91851457]
],
];
assert!(a.std_axis(Axis(0), 1.5).all_close(
&aview2(&[
[ 0.05989184, 0.36051836, 0.00989781, 0.32669847],
[ 0.81957535, 0.39599997, 0.49731472, 0.17084346],
[ 0.07044443, 0.06795249, 0.09794304, 0.83195211],
]),
1e-4,
));
assert!(a.std_axis(Axis(1), 1.7).all_close(
&aview2(&[
[ 0.42698655, 0.48139215, 0.36874991, 0.41458724],
[ 0.26769097, 0.18941435, 0.30555015, 0.35118674],
]),
1e-8,
));
assert!(a.std_axis(Axis(2), 2.3).all_close(
&aview2(&[
[ 0.41117907, 0.37130425, 0.35332388],
[ 0.16905862, 0.25304841, 0.39978276],
]),
1e-8,
));

let b = array![[100000., 1., 0.01]];
assert!(b.std_axis(Axis(0), 0.).all_close(&aview1(&[0., 0., 0.]), 1e-12));
assert!(
b.std_axis(Axis(1), 0.).all_close(&aview1(&[47140.214021552769]), 1e-6),
);

let c = array![[], []];
assert_eq!(c.std_axis(Axis(0), 0.), aview1(&[]));
}

#[test]
#[should_panic]
fn var_axis_negative_ddof() {
let a = array![1., 2., 3.];
a.var_axis(Axis(0), -1.);
}

#[test]
#[should_panic]
fn var_axis_too_large_ddof() {
let a = array![1., 2., 3.];
a.var_axis(Axis(0), 4.);
}

#[test]
fn var_axis_nan_ddof() {
let a = Array2::<f64>::zeros((2, 3));
let v = a.var_axis(Axis(1), ::std::f64::NAN);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn var_axis_empty_axis() {
let a = Array2::<f64>::zeros((2, 0));
let v = a.var_axis(Axis(1), 0.);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
#[should_panic]
fn std_axis_bad_dof() {
let a = array![1., 2., 3.];
a.std_axis(Axis(0), 4.);
}

#[test]
fn std_axis_empty_axis() {
let a = Array2::<f64>::zeros((2, 0));
let v = a.std_axis(Axis(1), 0.);
assert_eq!(v.shape(), &[2]);
v.mapv(|x| assert!(x.is_nan()));
}

#[test]
fn iter_size_hint()
{
Expand Down
2 changes: 1 addition & 1 deletion tests/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ fn complex_mat_mul()
let r = a.dot(&e);
println!("{}", a);
assert_eq!(r, a);
assert_eq!(a.mean_axis(Axis(0)), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
assert_eq!(a.mean_axis(Axis(0)).unwrap(), arr1(&[c(1.5, 1.), c(2.5, 0.)]));
}
Loading