Skip to content

Commit

Permalink
prevent division by zero in bincrossentropy
Browse files Browse the repository at this point in the history
  • Loading branch information
retraigo committed Dec 9, 2023
1 parent 6eed862 commit 3d1c3c7
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions crates/core/src/cpu/cost.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{ops::{Div, Mul, Sub}, f32::EPSILON};
use std::{
f32::EPSILON,
ops::{Add, Div, Mul, Sub},
};

use ndarray::{s, ArrayD, ArrayViewD};

Expand Down Expand Up @@ -46,7 +49,14 @@ fn cross_entropy<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> f32
let batches = y_hat.dim()[0];
let mut total = 0.0;
for b in 0..batches {
total -= &y.slice(s![b, ..]).mul(&y_hat.slice(s![b, ..]).map(|x| x.max(EPSILON).min(1f32 - EPSILON).ln())).sum()
total -= &y
.slice(s![b, ..])
.mul(
&y_hat
.slice(s![b, ..])
.map(|x| x.max(EPSILON).min(1f32 - EPSILON).ln()),
)
.sum()
}
return total / batches as f32;
}
Expand All @@ -64,7 +74,7 @@ fn bin_cross_entropy<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) ->
}

fn bin_cross_entropy_prime<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> ArrayD<f32> {
return y.sub(&y_hat).div(y.mul(1.0.sub(&y)));
return y.sub(&y_hat).div(y.mul(1.0.sub(&y)).add(EPSILON));
}

fn hinge<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> f32 {
Expand Down

0 comments on commit 3d1c3c7

Please sign in to comment.