Skip to content

Commit

Permalink
Use custom impl and time it
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 71bd43a commit 8d9218d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion candle-nn/src/layer_norm.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -233,7 +233,7 @@ impl crate::Module for LayerNorm {
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
- let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
+ let func = dev.get_or_load_func("rms_f32", kernels::LAYERNORM_KERNELS)?;
let params = (&dst, &slice, self.eps, d1, d2);
let cfg = LaunchConfig {
grid_dim: (d1, 1, 1),
Expand Down

0 comments on commit 8d9218d

Please sign in to comment.