Skip to content

Commit

Permalink
Restore example
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jul 18, 2024
1 parent c0aec5e commit 2001415
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions examples/gelu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,20 @@ fn gelu_scalar<F: Float>(x: F) -> F {
}

pub fn launch<R: Runtime>(device: &R::Device) {
type Primitive = half::f16;
type CubeType = F16;

let client = R::client(device);
let input = &[-1., 0., 1., 5.].map(|f| Primitive::from_f32(f));

let output_handle = client.empty(input.len() * core::mem::size_of::<Primitive>());
let input_handle = client.create(Primitive::as_bytes(input));
let input = &[-1., 0., 1., 5.];
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());

gelu_array::launch::<CubeType, R>(
gelu_array::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::new(input.len() as u32 / 4, 1, 1),
ArrayArg::vectorized(4, &input_handle, input.len()),
ArrayArg::vectorized(4, &output_handle, input.len()),
CubeDim::new(input.len() as u32, 1, 1),
ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()),
ArrayArg::new(&output_handle, input.len()),
);

let bytes = client.read(output_handle.binding());
let output = Primitive::from_bytes(&bytes);
let output = f32::from_bytes(&bytes);

// Should be [-0.1587, 0.0000, 0.8413, 5.0000]
println!("Executed gelu with runtime {:?} => {output:?}", R::name());
Expand Down

0 comments on commit 2001415

Please sign in to comment.