Skip to content

Commit

Permalink
avoid log(0) and div by 0
Browse files Browse the repository at this point in the history
  • Loading branch information
retraigo committed Dec 9, 2023
1 parent 36750a8 commit 3756045
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions crates/core/src/cpu/cost.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Div, Mul, Sub};
use std::{ops::{Div, Mul, Sub}, f32::EPSILON};

use ndarray::{s, ArrayD, ArrayViewD};

Expand Down Expand Up @@ -38,26 +38,27 @@ fn mse<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> f32 {
}

fn mse_prime<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> ArrayD<f32> {
//println!("{:?} - {:?}", y, y_hat);
return y.sub(&y_hat);
}

fn cross_entropy<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> f32 {
let batches = y_hat.dim()[0];
let mut total = 0.0;
for b in 0..batches {
total -= y_hat.slice(s![b, ..]).mul(&y.slice(s![b, ..])).sum().ln()
total -= &y.slice(s![b, ..]).mul(&y_hat.slice(s![b, ..]).map(|x| x.max(EPSILON).ln())).sum()
}
return total / batches as f32;
}

fn cross_entropy_prime<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> ArrayD<f32> {
return -y_hat.div(&y);
return -y_hat.div(&y.map(|x| x.max(EPSILON)));
}

fn bin_cross_entropy<'a>(y_hat: ArrayViewD<'a, f32>, y: ArrayViewD<'a, f32>) -> f32 {
return -y_hat
.mul(y.map(|x| x.ln()))
.sub(((1.0).sub(&y_hat)).mul(y.map(|x| 1.0 - x.ln())))
.mul(y.map(|x| x.max(EPSILON).min(1f32 - EPSILON).ln()))
.sub(((1.0).sub(&y_hat)).mul(y.map(|x| 1.0 - x.max(EPSILON).ln())))
.sum()
/ y.len() as f32;
}
Expand Down

0 comments on commit 3756045

Please sign in to comment.