Move some shared functions to the nn module. (huggingface#221)
LaurentMazare authored Jul 22, 2023
1 parent 43c7223 commit 1f26042
Showing 4 changed files with 24 additions and 19 deletions.
23 changes: 4 additions & 19 deletions candle-examples/examples/simple-training/main.rs
@@ -3,27 +3,12 @@
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle::{DType, Tensor, Var, D};
+use candle::{DType, Var, D};
+use candle_nn::{loss, ops};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
 
-fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> candle::Result<Tensor> {
-    let d = d.to_index(xs.shape(), "log-softmax")?;
-    let max = xs.max_keepdim(d)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let sum_exp = diff.exp()?.sum_keepdim(d)?;
-    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
-    Ok(log_sm)
-}
-
-fn nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
-    let b_sz = target.dim(0)?;
-    inp.gather(target, 1)?
-        .sum_all()?
-        .affine(-1f64 / b_sz as f64, 0.)
-}
-
 pub fn main() -> Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
     let m = candle_nn::vision::mnist::load_dir("data")?;
@@ -41,8 +26,8 @@ pub fn main() -> Result<()> {
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
         let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
-        let log_sm = log_softmax(&logits, D::Minus1)?;
-        let loss = nll_loss(&log_sm, &train_labels)?;
+        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
+        let loss = loss::nll(&log_sm, &train_labels)?;
         sgd.backward_step(&loss)?;
 
         let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
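For reference, the two helpers the example now calls compose into the standard cross-entropy objective (notation mine, not part of the commit): for logits z over a batch of size B with class labels y,

    loss = -\frac{1}{B} \sum_{i=1}^{B} \left( z_{i, y_i} - \log \sum_{j} e^{z_{i, j}} \right)

which is what ops::log_softmax followed by loss::nll computes in the training loop above.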
2 changes: 2 additions & 0 deletions candle-nn/src/lib.rs
@@ -6,6 +6,8 @@ pub mod embedding;
 pub mod init;
 pub mod layer_norm;
 pub mod linear;
+pub mod loss;
+pub mod ops;
 pub mod optim;
 pub mod var_builder;
 pub mod vision;
8 changes: 8 additions & 0 deletions candle-nn/src/loss.rs
@@ -0,0 +1,8 @@
+use candle::{Result, Tensor};
+
+pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    let b_sz = target.dim(0)?;
+    inp.gather(target, 1)?
+        .sum_all()?
+        .affine(-1f64 / b_sz as f64, 0.)
+}
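Spelled out (notation mine): given per-row log-probabilities p — the output of log_softmax — and integer targets y over a batch of size B,

    \mathrm{nll}(p, y) = -\frac{1}{B} \sum_{i=1}^{B} p_{i, y_i}

gather selects p_{i, y_i} for each row, sum_all adds the selected entries, and affine(-1/B, 0) applies the averaging and the sign flip.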
10 changes: 10 additions & 0 deletions candle-nn/src/ops.rs
@@ -0,0 +1,10 @@
+use candle::{Result, Tensor};
+
+pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let sum_exp = diff.exp()?.sum_keepdim(d)?;
+    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
+    Ok(log_sm)
+}
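The implementation uses the usual max-subtraction trick for numerical stability; in symbols (notation mine), with m = max_j x_j taken along the chosen dimension d,

    \operatorname{log\_softmax}(x)_i = (x_i - m) - \log \sum_{j} e^{x_j - m}

Subtracting m before exponentiating keeps every exponent at or below zero, so the sum cannot overflow for large logits.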
