From c5b77d32eb87729459955bc4ea8e844c3b661b1f Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 10 Nov 2023 10:49:33 +0100 Subject: [PATCH] simplify threads/threadgroups calculation --- candle-metal-kernels/Cargo.toml | 10 +- candle-metal-kernels/examples/affine.rs | 23 +++- candle-metal-kernels/examples/binary.rs | 48 +++++++- candle-metal-kernels/examples/cast.rs | 22 +++- candle-metal-kernels/examples/reduce.rs | 102 ++++++++++++++++ candle-metal-kernels/examples/ternary.rs | 147 +++++++++++++++++++++++ candle-metal-kernels/examples/unary.rs | 47 +++++++- candle-metal-kernels/src/affine.metal | 12 +- candle-metal-kernels/src/binary.metal | 32 ++--- candle-metal-kernels/src/cast.metal | 28 ++--- candle-metal-kernels/src/lib.rs | 41 ++++--- candle-metal-kernels/src/ternary.metal | 55 +++++---- candle-metal-kernels/src/unary.metal | 28 ++--- 13 files changed, 461 insertions(+), 134 deletions(-) create mode 100644 candle-metal-kernels/examples/reduce.rs create mode 100644 candle-metal-kernels/examples/ternary.rs diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index d6bc670657..46e2d4a15a 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -17,12 +17,4 @@ tracing = "0.1.37" [dev-dependencies] half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } -rand = "0.8.5" - -[[example]] -name = "unary" -path = "examples/unary.rs" - -[[example]] -name = "binary" -path = "examples/binary.rs" +rand = "0.8.5" \ No newline at end of file diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/examples/affine.rs index b8005dc0ae..c028edec6c 100644 --- a/candle-metal-kernels/examples/affine.rs +++ b/candle-metal-kernels/examples/affine.rs @@ -18,7 +18,7 @@ fn main() { .collect::>(); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5:}", "dtype", "kernel", "size", "runs", "total time", "avg time" ); @@ -42,6 +42,25 @@ fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { let mul: f32 = 1.2345; let add: f32 = 2.3456; + + // Ghost pass to ensure kernel load time is not included in benchmarks + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_affine( + &device, + command_buffer, + &kernels, + v.len(), + &input, + &mut output, + mul, + add, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); let start = Instant::now(); @@ -64,7 +83,7 @@ fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { start.elapsed() }); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", type_name::().split("::").last().unwrap(), "affine", v.len(), diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/examples/binary.rs index af5a8bdc62..d51d1672c1 100644 --- a/candle-metal-kernels/examples/binary.rs +++ b/candle-metal-kernels/examples/binary.rs @@ -66,7 +66,7 @@ fn main() { ]; println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5}", "dtype", "kernel", "size", "runs", "total time", "avg time" ); @@ -105,6 +105,25 @@ fn run_binary_bench( let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, 
options); // Contiguous + // Ghost pass to ensure kernel load time is not included in benchmarks + for kernel_name in contiguous { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_binary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } for kernel_name in contiguous { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); @@ -128,7 +147,7 @@ fn run_binary_bench( start.elapsed() }); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", type_name::().split("::").last().unwrap(), kernel_name.to_string(), v.len(), @@ -142,6 +161,29 @@ fn run_binary_bench( let shape = vec![2, 5_000]; let strides = vec![2, 1]; let offset = 0; + // Ghost pass to ensure kernel load time is not included in benchmarks + for kernel_name in strided { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_binary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &input, + &strides, + offset, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } for kernel_name in strided { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); @@ -170,7 +212,7 @@ fn run_binary_bench( }); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", type_name::().split("::").last().unwrap(), kernel_name.to_string(), v.len(), diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/examples/cast.rs index 090f510d16..26a0e0c188 100644 --- a/candle-metal-kernels/examples/cast.rs +++ b/candle-metal-kernels/examples/cast.rs @@ -20,7 +20,7 @@ fn main() { let contiguous_kernels = ["cast_u32_f32"]; println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5:}", "dtype", "kernel", "size", "runs", "total time", "avg time" ); @@ -48,6 +48,24 @@ fn run_cast_bench( let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); // Contiguous + // Ghost pass to ensure kernel load time is not included in benchmarks + for kernel_name in contiguous { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_cast_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } for kernel_name in contiguous { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); @@ -70,7 +88,7 @@ fn run_cast_bench( start.elapsed() }); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", type_name::().split("::").last().unwrap(), kernel_name.to_string(), v.len(), diff --git a/candle-metal-kernels/examples/reduce.rs b/candle-metal-kernels/examples/reduce.rs new file mode 100644 index 0000000000..076035145a --- /dev/null +++ b/candle-metal-kernels/examples/reduce.rs @@ -0,0 +1,102 @@ +use candle_metal_kernels::{call_reduce_contiguous, Kernels}; +use metal::objc::rc::autoreleasepool; +use 
metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let reduce_kernels = ["fast_sum_float", "softmax_float"]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_reduce_bench(&device, &kernels, &f32_1k, reduce_kernels); + run_reduce_bench(&device, &kernels, &f32_10k, reduce_kernels); + run_reduce_bench(&device, &kernels, &f32_100k, reduce_kernels); +} + +fn run_reduce_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + reduce_kernels: [&'static str; 2], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + // Ghost pass to ensure kernel load time is not included in benchmarks + for kernel_name in reduce_kernels { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_reduce_contiguous( + &device, + command_buffer, + &kernels, + kernel_name, + v.len(), + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } + for kernel_name in reduce_kernels { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_reduce_contiguous( + &device, + command_buffer, + &kernels, + kernel_name, + v.len(), + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/candle-metal-kernels/examples/ternary.rs b/candle-metal-kernels/examples/ternary.rs new file mode 100644 index 0000000000..71308ee0e7 --- /dev/null +++ b/candle-metal-kernels/examples/ternary.rs @@ -0,0 +1,147 @@ +use candle_metal_kernels::{call_where_cond_strided, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f32_strided_k = [ + "where_i64_f32", + "where_u32_f32", + "where_u8_f32", + "where_i64_f16", + "where_u32_f16", + "where_u8_f16", + "where_i64_u8", + "where_u32_u8", + "where_u8_u8", + "where_i64_u32", + "where_u32_u32", + "where_u8_u32", + "where_i64_i64", + "where_u32_i64", + "where_u8_i64", + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5:}", + "dtype", "kernel", "size", "runs", 
"total time", "avg time" + ); + + // f32 + run_ternary_bench(&device, &kernels, &f32_1k, f32_strided_k); + run_ternary_bench(&device, &kernels, &f32_10k, f32_strided_k); + run_ternary_bench(&device, &kernels, &f32_100k, f32_strided_k); +} + +fn run_ternary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + // contiguous: [&'static str; 1], + strided: [&'static str; 15], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + let iterations = 1000; + + let length = v.len() as i32; + + let shape = vec![2, v.len() / 2]; + + // Stride is the biggest factor to the complexity of the kernel + let stride: Vec<_> = (0..v.len() / 50).map(|i| i).collect(); + let cond = (0..length).map(|i| i % 2).collect::>(); + let left = (0..length).step_by(1).collect::>(); + let right = left.iter().map(|v| -*v).collect::>(); + + let cond = device.new_buffer_with_data( + cond.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&cond) as u64, + options, + ); + let left = device.new_buffer_with_data( + left.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&left) as u64, + options, + ); + let right = device.new_buffer_with_data( + right.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&right) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Ghost pass to ensure kernel load time is not included in benchmarks + for kernel_name in strided { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_where_cond_strided( + device, + command_buffer, + &kernels, + &kernel_name, + &shape, + &cond, + (stride.as_slice(), 0), + &left, + (stride.as_slice(), 0), + &right, + (stride.as_slice(), 0), + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } + for kernel_name in strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_where_cond_strided( + device, + command_buffer, + &kernels, + &kernel_name, + &shape, + &cond, + (stride.as_slice(), 0), + &left, + (stride.as_slice(), 0), + &right, + (stride.as_slice(), 0), + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/examples/unary.rs index 7039c0985a..09822a460f 100644 --- a/candle-metal-kernels/examples/unary.rs +++ b/candle-metal-kernels/examples/unary.rs @@ -84,7 +84,7 @@ fn main() { ]; println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5:}", "dtype", "kernel", "size", "runs", "total time", "avg time" ); @@ -114,7 +114,7 @@ fn run_unary_bench( let command_queue = device.new_command_queue(); let options = MTLResourceOptions::StorageModeManaged; - let iterations = 10000; + let iterations = 1000; let input = device.new_buffer_with_data( v.as_ptr() as *const core::ffi::c_void, core::mem::size_of_val(v) as u64, @@ -123,6 +123,24 @@ fn run_unary_bench( let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); // Contiguous + // Ghost pass to ensure 
kernel load time is not included in benchmarks + for kernel_name in contiguous { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_unary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } for kernel_name in contiguous { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); @@ -145,7 +163,7 @@ fn run_unary_bench( start.elapsed() }); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", type_name::().split("::").last().unwrap(), kernel_name.to_string(), v.len(), @@ -159,6 +177,27 @@ fn run_unary_bench( let shape = vec![2, 5_000]; let strides = vec![2, 1]; let offset = 0; + // Ghost pass to ensure kernel load time is not included in benchmarks + for kernel_name in strided { + autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + call_unary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &mut output, + 0, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + }); + } for kernel_name in strided { let total_time = autoreleasepool(|| { let command_buffer = command_queue.new_command_buffer(); @@ -185,7 +224,7 @@ fn run_unary_bench( }); println!( - "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}", type_name::().split("::").last().unwrap(), kernel_name.to_string(), v.len(), diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 54729624b2..48413cfb5c 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -24,18 +24,12 @@ kernel void FN_NAME( \ constant float &add, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]] \ + uint gid [[thread_position_in_grid]] \ ) { \ - uint thread_position_in_grid = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (thread_position_in_grid >= dim) { \ - return; \ - } \ + if (gid >= dim) return; \ const TYPENAME m = TYPENAME(mul); \ const TYPENAME a = TYPENAME(add); \ - output[thread_position_in_grid] = input[thread_position_in_grid] * m + a; \ + output[gid] = input[gid] * m + a; \ } \ AFFINE(affine_float, float) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index a3e636f559..600e01c097 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -23,18 +23,12 @@ kernel void FN_NAME( \ device const TYPENAME *left, \ device const TYPENAME *right, \ device TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]] \ + uint gid [[thread_position_in_grid]] \ ) { \ - uint thread_position_in_grid = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (thread_position_in_grid >= dim) { \ - return; \ - } \ - TYPENAME x = 
left[thread_position_in_grid]; \ - TYPENAME y = right[thread_position_in_grid]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + if (gid >= dim) return; \ + TYPENAME x = left[gid]; \ + TYPENAME y = right[gid]; \ + output[gid] = OUT_TYPENAME(FN); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -45,18 +39,12 @@ kernel void FN_NAME_STRIDED( \ device const TYPENAME *left, \ device const TYPENAME *right, \ device TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]] \ + uint gid [[thread_position_in_grid]] \ ) { \ - uint thread_position_in_grid = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (thread_position_in_grid >= dim) { \ - return; \ - } \ - TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ - TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + if (gid >= dim) return; \ + TYPENAME x = left[get_strided_index(gid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(gid, num_dims, dims, left_strides)]; \ + output[gid] = OUT_TYPENAME(FN); \ } #define BINARY_OP(FN, NAME) \ diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 09016ccd18..6fc905516a 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,16 +23,10 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]]) \ -{ \ - uint thread_position_in_grid = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (thread_position_in_grid >= dim) { \ - return; \ - } \ - output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ + uint gid [[thread_position_in_grid]] \ +) { \ + if (gid >= dim) return; \ + output[gid] = RIGHT_TYPENAME(input[gid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -41,16 +35,10 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]]) \ -{ \ - uint i = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (i >= dim) { \ - return; \ - } \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + uint gid [[thread_position_in_grid]] \ +) { \ + if (gid >= dim) return; \ + output[gid] = RIGHT_TYPENAME(input[get_strided_index(gid, num_dims, dims, strides)]); \ } \ CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index fb79a881c7..29c5bac333 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -190,8 +190,7 @@ pub fn call_unary_contiguous( pipeline.max_total_threads_per_threadgroup(), 
length as NSUInteger, ); - let remainder = length as NSUInteger % threads; - let thread_groups = length as NSUInteger / threads + if remainder == 0 { 0 } else { 1 }; + let thread_groups = (length as NSUInteger + threads - 1) / threads; let diff = (threads * thread_groups) - length as NSUInteger; let threads = threads - (diff / thread_groups); @@ -257,8 +256,7 @@ pub fn call_unary_strided( pipeline.max_total_threads_per_threadgroup(), length as NSUInteger, ); - let remainder = length as NSUInteger % threads; - let thread_groups = length as NSUInteger / threads + if remainder == 0 { 0 } else { 1 }; + let thread_groups = (length as NSUInteger + threads - 1) / threads; let diff = (threads * thread_groups) - length as NSUInteger; let threads = threads - (diff / thread_groups); @@ -312,8 +310,7 @@ pub fn call_binary_contiguous( pipeline.max_total_threads_per_threadgroup(), length as NSUInteger, ); - let remainder = length as NSUInteger % threads; - let thread_groups = length as NSUInteger / threads + if remainder == 0 { 0 } else { 1 }; + let thread_groups = (length as NSUInteger + threads - 1) / threads; let diff = (threads * thread_groups) - length as NSUInteger; let threads = threads - (diff / thread_groups); @@ -389,8 +386,7 @@ pub fn call_binary_strided( pipeline.max_total_threads_per_threadgroup(), length as NSUInteger, ); - let remainder = length as NSUInteger % threads; - let thread_groups = length as NSUInteger / threads + if remainder == 0 { 0 } else { 1 }; + let thread_groups = (length as NSUInteger + threads - 1) / threads; let diff = (threads * thread_groups) - length as NSUInteger; let threads = threads - (diff / thread_groups); @@ -442,8 +438,7 @@ pub fn call_cast_contiguous( pipeline.max_total_threads_per_threadgroup(), length as NSUInteger, ); - let remainder = length as NSUInteger % threads; - let thread_groups = length as NSUInteger / threads + if remainder == 0 { 0 } else { 1 }; + let thread_groups = (length as NSUInteger + threads - 1) / threads; let diff = (threads * thread_groups) - length as NSUInteger; let threads = threads - (diff / thread_groups); @@ -617,8 +612,7 @@ pub fn call_affine( pipeline.max_total_threads_per_threadgroup(), length as NSUInteger, ); - let remainder = length as NSUInteger % threads; - let thread_groups = length as NSUInteger / threads + if remainder == 0 { 0 } else { 1 }; + let thread_groups = (length as NSUInteger + threads - 1) / threads; let diff = (threads * thread_groups) - length as NSUInteger; let threads = threads - (diff / thread_groups); @@ -697,20 +691,27 @@ pub fn call_where_cond_strided( encoder.set_buffer(8, Some(right), right_offset as u64); encoder.set_buffer(9, Some(output), 0); + let length = size as NSUInteger; + let threads = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + length as NSUInteger, + ); + let thread_groups = (length as NSUInteger + threads - 1) / threads; + let diff = (threads * thread_groups) - length as NSUInteger; + let threads = threads - (diff / thread_groups); + let thread_group_count = MTLSize { - width: 1, + width: thread_groups, height: 1, depth: 1, }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, + let threads_per_threadgroup = MTLSize { + width: threads, + height: threads, + depth: threads, }; - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.dispatch_thread_groups(thread_group_count, threads_per_threadgroup); encoder.end_encoding(); Ok(()) } 
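Not part of the patch, for reference only: the sizing logic that each call_* helper above now repeats can be read as one small routine. The name linear_split below is hypothetical, and the sketch assumes a plain 1-D element-wise dispatch (height and depth stay 1); it also leaves out the extra step in the patch that shrinks threads by diff / thread_groups, since the kernel-side guard if (gid >= dim) return; already absorbs the overshoot from the ceiling division.

use metal::{ComputePipelineState, MTLSize, NSUInteger};

// Sketch of the dispatch sizing used by the call_* helpers (1-D dispatch assumed).
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as NSUInteger;
    // Threadgroup width is capped by the pipeline's hardware limit.
    let threads = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    // Ceiling division: enough threadgroups to cover every element.
    let thread_groups = (size + threads - 1) / threads;
    let thread_group_count = MTLSize { width: thread_groups, height: 1, depth: 1 };
    let threads_per_threadgroup = MTLSize { width: threads, height: 1, depth: 1 };
    (thread_group_count, threads_per_threadgroup)
}

A call site would then reduce to let (groups, group_size) = linear_split(&pipeline, length); followed by encoder.dispatch_thread_groups(groups, group_size);.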
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0945b355cf..7ce93ea5fd 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -17,7 +17,6 @@ METAL_FUNC uint get_strided_index( return strided_i; } - #define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ kernel void FN_NAME( \ constant size_t &numel, \ @@ -30,28 +29,38 @@ kernel void FN_NAME( \ device const TYPENAME *t, \ device const TYPENAME *f, \ device TYPENAME *out ,\ - uint i [[ thread_position_in_grid ]] \ -) { \ - uint strided_i = get_strided_index(i, num_dims, dims, strides); \ - uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ - uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ - out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ + uint gid [[thread_position_in_grid]] \ +) { \ + if (gid >= numel) return; \ + uint strided_i = get_strided_index(gid, num_dims, dims, strides); \ + uint strided_i_t = get_strided_index(gid, num_dims, dims, strides_t); \ + uint strided_i_f = get_strided_index(gid, num_dims, dims, strides_f); \ + out[gid] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ } \ -// WHERE_OP(float, int64_t, where_i64_f32) -// WHERE_OP(double, int64_t, where_i64_f64) -// WHERE_OP(uint8_t, int64_t, where_i64_u8) -// WHERE_OP(uint32_t, int64_t, where_i64_u32) -// WHERE_OP(int64_t, int64_t, where_i64_i64) -// -// WHERE_OP(float, uint32_t, where_u32_f32) -// WHERE_OP(double, uint32_t, where_u32_f64) -// WHERE_OP(uint8_t, uint32_t, where_u32_u8) -// WHERE_OP(uint32_t, uint32_t, where_u32_u32) -// WHERE_OP(int64_t, uint32_t, where_u32_i64) - +WHERE_OP(float, int64_t, where_i64_f32) +WHERE_OP(float, uint32_t, where_u32_f32) WHERE_OP(float, uint8_t, where_u8_f32) -// WHERE_OP(double, uint8_t, where_u8_f64) -// WHERE_OP(uint8_t, uint8_t, where_u8_u8) -// WHERE_OP(uint32_t, uint8_t, where_u8_u32) -// WHERE_OP(int64_t, uint8_t, where_u8_i64) + +WHERE_OP(half, int64_t, where_i64_f16) +WHERE_OP(half, uint32_t, where_u32_f16) +WHERE_OP(half, uint8_t, where_u8_f16) + +WHERE_OP(uint8_t, int64_t, where_i64_u8) +WHERE_OP(uint8_t, uint32_t, where_u32_u8) +WHERE_OP(uint8_t, uint8_t, where_u8_u8) + +WHERE_OP(uint32_t, int64_t, where_i64_u32) +WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(uint32_t, uint8_t, where_u8_u32) + +WHERE_OP(int64_t, int64_t, where_i64_i64) +WHERE_OP(int64_t, uint32_t, where_u32_i64) +WHERE_OP(int64_t, uint8_t, where_u8_i64) + + +#if __METAL_VERSION__ >= 310 +WHERE_OP(bfloat, int64_t, where_i64_bf16) +WHERE_OP(bfloat, uint32_t, where_u32_bf16) +WHERE_OP(bfloat, uint8_t, where_u8_bf16) +#endif \ No newline at end of file diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index d21c57b815..3b498d93a8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -27,16 +27,10 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]]) \ -{ \ - uint thread_position_in_grid = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (thread_position_in_grid >= dim) { \ - return; \ - } \ - output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ + uint gid 
[[thread_position_in_grid]] \ +) { \ + if (gid >= dim) return; \ + output[gid] = TYPENAME(FN(input[gid])); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -45,16 +39,10 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_position_in_grid [[ threadgroup_position_in_grid ]], \ - uint thread_position_in_threadgroup [[ thread_position_in_threadgroup ]], \ - uint threads_per_threadgroup [[ threads_per_threadgroup ]]) \ -{ \ - uint thread_position_in_grid = (threadgroup_position_in_grid * threads_per_threadgroup) + \ - thread_position_in_threadgroup; \ - if (thread_position_in_grid >= dim) { \ - return; \ - } \ - output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ + uint gid [[thread_position_in_grid]] \ +) { \ + if (gid >= dim) return; \ + output[gid] = TYPENAME(FN(input[get_strided_index(gid, num_dims, dims, strides)])); \ } #define UNARY_OP(NAME) \
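Not part of the patch, for reference only: a minimal end-to-end check of the simplified dispatch, written against the call_affine signature used in examples/affine.rs above. StorageModeShared, the 1_003-element length, and the expected value are assumptions made for the illustration; the point is that a length which does not divide evenly into the threadgroup width is still fully written, because every kernel now guards with if (gid >= dim) return;.

use candle_metal_kernels::{call_affine, Kernels};
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // 1_003 elements: not a multiple of any power-of-two threadgroup width.
    let v: Vec<f32> = (0..1_003).map(|i| i as f32).collect();
    let options = MTLResourceOptions::StorageModeShared;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v.as_slice()) as u64,
        options,
    );
    let mut output = device.new_buffer(core::mem::size_of_val(v.as_slice()) as u64, options);

    call_affine(&device, command_buffer, &kernels, v.len(), &input, &mut output, 2.0, 1.0).unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Shared storage, so the result is visible to the CPU once the command buffer completes.
    let out = unsafe { std::slice::from_raw_parts(output.contents() as *const f32, v.len()) };
    assert_eq!(out[v.len() - 1], v[v.len() - 1] * 2.0 + 1.0);
}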