diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 3ec813056..ee49607e3 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -13,25 +13,20 @@ fn gelu_scalar(x: F) -> F { } pub fn launch(device: &R::Device) { - type Primitive = half::f16; - type CubeType = F16; - let client = R::client(device); - let input = &[-1., 0., 1., 5.].map(|f| Primitive::from_f32(f)); - - let output_handle = client.empty(input.len() * core::mem::size_of::()); - let input_handle = client.create(Primitive::as_bytes(input)); + let input = &[-1., 0., 1., 5.]; + let output_handle = client.empty(input.len() * core::mem::size_of::()); - gelu_array::launch::( + gelu_array::launch::( client.clone(), CubeCount::Static(1, 1, 1), - CubeDim::new(input.len() as u32 / 4, 1, 1), - ArrayArg::vectorized(4, &input_handle, input.len()), - ArrayArg::vectorized(4, &output_handle, input.len()), + CubeDim::new(input.len() as u32, 1, 1), + ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()), + ArrayArg::new(&output_handle, input.len()), ); let bytes = client.read(output_handle.binding()); - let output = Primitive::from_bytes(&bytes); + let output = f32::from_bytes(&bytes); // Should be [-0.1587, 0.0000, 0.8413, 5.0000] println!("Executed gelu with runtime {:?} => {output:?}", R::name());