Skip to content

Commit

Permalink
return array in std_mean (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba authored Jan 9, 2025
1 parent bbc9ada commit ac71d54
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions crates/kornia-imgproc/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use rayon::{
/// assert_eq!(std, [93.5183805462862, 93.5183805462862, 93.5183805462862]);
/// assert_eq!(mean, [111.25, 112.25, 113.25]);
/// ```
pub fn std_mean(image: &Image<u8, 3>) -> (Vec<f64>, Vec<f64>) {
pub fn std_mean(image: &Image<u8, 3>) -> ([f64; 3], [f64; 3]) {
let (sum, sq_sum) = image.as_slice().chunks_exact(3).fold(
([0f64; 3], [0f64; 3]),
|(mut sum, mut sq_sum), pixel| {
Expand All @@ -55,13 +55,13 @@ pub fn std_mean(image: &Image<u8, 3>) -> (Vec<f64>, Vec<f64>) {
);

let n = (image.width() * image.height()) as f64;
let mean = sum.iter().map(|&s| s / n).collect::<Vec<_>>();
let mean = [sum[0] / n, sum[1] / n, sum[2] / n];

let variance = sq_sum
.iter()
.zip(mean.iter())
.map(|(&sq_s, &m)| (sq_s / n - m.powi(2)).sqrt())
.collect::<Vec<_>>();
let variance = [
(sq_sum[0] / n - mean[0].powi(2)).sqrt(),
(sq_sum[1] / n - mean[1].powi(2)).sqrt(),
(sq_sum[2] / n - mean[2].powi(2)).sqrt(),
];

(variance, mean)
}
Expand Down Expand Up @@ -268,9 +268,12 @@ mod tests {
vec![0, 1, 2, 253, 254, 255, 128, 129, 130, 64, 65, 66],
)?;

let std_expected = [93.5183805462862, 93.5183805462862, 93.5183805462862];
let mean_expected = [111.25, 112.25, 113.25];

let (std, mean) = super::std_mean(&image);
assert_eq!(std, [93.5183805462862, 93.5183805462862, 93.5183805462862]);
assert_eq!(mean, [111.25, 112.25, 113.25]);
assert_eq!(std, std_expected);
assert_eq!(mean, mean_expected);
Ok(())
}

Expand Down

0 comments on commit ac71d54

Please sign in to comment.