From fc53b4132ea4796aff84926af42339b61ed3fb84 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 10 Nov 2023 17:16:56 +0100
Subject: [PATCH] metal backend: implement index_add, improve kernel dispatch,
 and add an indexing benchmark

---
 candle-core/src/metal_backend.rs          |  79 +++++++-
 candle-metal-kernels/examples/indexing.rs | 220 ++++++++++++++++++++++
 candle-metal-kernels/src/indexing.metal   |  54 +++---
 candle-metal-kernels/src/lib.rs           | 103 +++++++++-
 4 files changed, 408 insertions(+), 48 deletions(-)
 create mode 100644 candle-metal-kernels/examples/indexing.rs

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index e800231fc7..95d8a3d9f4 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1,7 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{CpuStorage, DType, Layout, Result, Shape};
+use crate::{bail, CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
 use candle_metal_kernels::Kernels;
 use core::mem;
@@ -542,14 +542,77 @@ impl BackendStorage for MetalStorage {
 
     fn index_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        todo!()
+        // TODO: Refactor this dispatch table into a helper; kernel names
+        // follow the pattern ia_<index dtype>_<value dtype>.
+        let kernel_name = match ids.dtype {
+            DType::U8 => match src.dtype {
+                DType::U8 => "ia_u8_u8",
+                DType::U32 => "ia_u8_u32",
+                DType::I64 => "ia_u8_i64",
+                DType::BF16 => "ia_u8_bf16",
+                DType::F16 => "ia_u8_f16",
+                DType::F32 => "ia_u8_f32",
+                _ => bail!("Unsupported dtype for index add"),
+            },
+            DType::U32 => match src.dtype {
+                DType::U8 => "ia_u32_u8",
+                DType::U32 => "ia_u32_u32",
+                DType::I64 => "ia_u32_i64",
+                DType::BF16 => "ia_u32_bf16",
+                DType::F16 => "ia_u32_f16",
+                DType::F32 => "ia_u32_f32",
+                _ => bail!("Unsupported dtype for index add"),
+            },
+            DType::I64 => match src.dtype {
+                DType::U8 => "ia_i64_u8",
+                DType::U32 => "ia_i64_u32",
+                DType::I64 => "ia_i64_i64",
+                DType::BF16 => "ia_i64_bf16",
+                DType::F16 => "ia_i64_f16",
+                DType::F32 => "ia_i64_f32",
+                _ => bail!("Unsupported dtype for index add"),
+            },
+            _ => bail!("Unsupported index dtype for index add"),
+        };
+
+        let device = self.device.clone();
+        let mut dst = device.zeros_impl(l.shape(), self.dtype())?;
+        self.copy_strided_src(&mut dst, 0, l)?;
+
+        // Flatten the shapes around `dim`: everything before the indexed
+        // dimension, the indexed dimension itself, and everything after it.
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let dst_dim_size = l.dims()[dim];
+        let ids_dim_size = ids_l.dims()[0];
+
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_index_add(
+            device.device(),
+            command_buffer,
+            &device.kernels,
+            kernel_name,
+            &ids.buffer,
+            &src.buffer,
+            &mut dst.buffer,
+            ids_dim_size as NSUInteger,
+            left_size as NSUInteger,
+            dst_dim_size as NSUInteger,
+            right_size as NSUInteger,
+        )
+        .unwrap();
+
+        command_buffer.commit();
+
+        Ok(Self {
+            buffer: dst.buffer,
+            device: self.device.clone(),
+            dtype: self.dtype(),
+        })
     }
 
     fn matmul(
diff --git a/candle-metal-kernels/examples/indexing.rs b/candle-metal-kernels/examples/indexing.rs
new file mode 100644
index 0000000000..6a09cff1dc
--- /dev/null
+++ b/candle-metal-kernels/examples/indexing.rs
@@ -0,0 +1,220 @@
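+//! Benchmark for the Metal `index_add` kernels.
+//!
+//! For every value dtype (f32, f16, bf16, u8, u32, i64) and every index dtype
+//! (u8, u32, i64), this runs `call_index_add` on inputs of 1k, 10k and 100k
+//! random elements and prints the total and per-iteration wall-clock time.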
+use candle_metal_kernels::{call_index_add, Kernels};
+use half::{bf16, f16};
+use metal::objc::rc::autoreleasepool;
+use metal::{Device, MTLResourceOptions, NSUInteger};
+use std::any::type_name;
+use std::time::Instant;
+
+macro_rules! ia_name {
+    ($idx_type:literal, $dtype:literal) => {
+        concat!("ia_", $idx_type, "_", $dtype)
+    };
+}
+
+macro_rules! ia_kernels {
+    ($dtype:literal) => {
+        [
+            ia_name!("u8", $dtype),
+            ia_name!("u32", $dtype),
+            ia_name!("i64", $dtype),
+        ]
+    };
+}
+
+// The same index values encoded in each of the supported index dtypes.
+struct IdsData {
+    u8: Vec<u8>,
+    u32: Vec<u32>,
+    i64: Vec<i64>,
+}
+
+fn main() {
+    let device = Device::system_default().unwrap();
+
+    let kernels = Kernels::new();
+    let u8_1k = (0..1000).map(|_| rand::random::<u8>()).collect::<Vec<u8>>();
+    let u8_10k = (0..10000).map(|_| rand::random::<u8>()).collect::<Vec<u8>>();
+    let u8_100k = (0..100000)
+        .map(|_| rand::random::<u8>())
+        .collect::<Vec<u8>>();
+
+    let f32_map = |v: &[u8]| v.iter().map(|v| *v as f32).collect::<Vec<f32>>();
+    let f32_1k = f32_map(&u8_1k);
+    let f32_10k = f32_map(&u8_10k);
+    let f32_100k = f32_map(&u8_100k);
+
+    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<f16>>();
+    let f16_1k = f16_map(&f32_1k);
+    let f16_10k = f16_map(&f32_10k);
+    let f16_100k = f16_map(&f32_100k);
+
+    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<bf16>>();
+    let bf16_1k = bf16_map(&f32_1k);
+    let bf16_10k = bf16_map(&f32_10k);
+    let bf16_100k = bf16_map(&f32_100k);
+
+    let u32_map = |v: &[u8]| v.iter().map(|v| *v as u32).collect::<Vec<u32>>();
+    let u32_1k = u32_map(&u8_1k);
+    let u32_10k = u32_map(&u8_10k);
+    let u32_100k = u32_map(&u8_100k);
+
+    let i64_map = |v: &[u8]| v.iter().map(|v| *v as i64).collect::<Vec<i64>>();
+    let i64_1k = i64_map(&u8_1k);
+    let i64_10k = i64_map(&u8_10k);
+    let i64_100k = i64_map(&u8_100k);
+
+    let f32_kernels = ia_kernels!("f32");
+    let f16_kernels = ia_kernels!("f16");
+    let bf16_kernels = ia_kernels!("bf16");
+    let u8_kernels = ia_kernels!("u8");
+    let u32_kernels = ia_kernels!("u32");
+    let i64_kernels = ia_kernels!("i64");
+
+    println!(
+        "{0: <5} | {1: <11} | {2: <6} | {3: <5} | {4: <12} | {5}",
+        "dtype", "kernel", "size", "runs", "total time", "avg time"
+    );
+
+    let ids_data_1k = IdsData {
+        u8: u8_1k.clone(),
+        u32: u32_1k.clone(),
+        i64: i64_1k.clone(),
+    };
+    let ids_data_10k = IdsData {
+        u8: u8_10k.clone(),
+        u32: u32_10k.clone(),
+        i64: i64_10k.clone(),
+    };
+    let ids_data_100k = IdsData {
+        u8: u8_100k.clone(),
+        u32: u32_100k.clone(),
+        i64: i64_100k.clone(),
+    };
+
+    // f32
+    run_indexing_benches(&device, &kernels, &ids_data_1k, &f32_1k, f32_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_10k, &f32_10k, f32_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_100k, &f32_100k, f32_kernels);
+
+    // f16
+    run_indexing_benches(&device, &kernels, &ids_data_1k, &f16_1k, f16_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_10k, &f16_10k, f16_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_100k, &f16_100k, f16_kernels);
+
+    // bf16
+    run_indexing_benches(&device, &kernels, &ids_data_1k, &bf16_1k, bf16_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_10k, &bf16_10k, bf16_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_100k, &bf16_100k, bf16_kernels);
+
+    // u8
+    run_indexing_benches(&device, &kernels, &ids_data_1k, &u8_1k, u8_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_10k, &u8_10k, u8_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_100k, &u8_100k, u8_kernels);
+
+    // u32
+    run_indexing_benches(&device, &kernels, &ids_data_1k, &u32_1k, u32_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_10k, &u32_10k, u32_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_100k, &u32_100k, u32_kernels);
+
+    // i64
+    run_indexing_benches(&device, &kernels, &ids_data_1k, &i64_1k, i64_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_10k, &i64_10k, i64_kernels);
+    run_indexing_benches(&device, &kernels, &ids_data_100k, &i64_100k, i64_kernels);
+}
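+
+// Runs one benchmark per index dtype: `index_add_kernels` holds the kernel
+// names for u8, u32 and i64 indices, in the order produced by `ia_kernels!`.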
+fn run_indexing_benches<T>(
+    device: &Device,
+    kernels: &Kernels,
+    ids: &IdsData,
+    input: &[T],
+    index_add_kernels: [&'static str; 3],
+) {
+    run_indexing_bench(device, kernels, &ids.u8, input, index_add_kernels[0]);
+    run_indexing_bench(device, kernels, &ids.u32, input, index_add_kernels[1]);
+    run_indexing_bench(device, kernels, &ids.i64, input, index_add_kernels[2]);
+}
+
+fn run_indexing_bench<T, U>(
+    device: &Device,
+    kernels: &Kernels,
+    ids: &[T],
+    input: &[U],
+    kernel_name: &'static str,
+) {
+    let command_queue = device.new_command_queue();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let iterations = 10000;
+    let ids_buffer = device.new_buffer_with_data(
+        ids.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(ids) as u64,
+        options,
+    );
+    let input_buffer = device.new_buffer_with_data(
+        input.as_ptr() as *const core::ffi::c_void,
+        core::mem::size_of_val(input) as u64,
+        options,
+    );
+    let mut output_buffer = device.new_buffer(core::mem::size_of_val(input) as u64, options);
+
+    // A 1-D index_add along dim 0: ids and src have the same length, so
+    // left_size and right_size are both 1. The indices are derived from u8
+    // data (max 255), so every index fits inside the destination dimension.
+    let ids_dim_size = ids.len() as NSUInteger;
+    let left_size: NSUInteger = 1;
+    let dst_dim_size = ids.len() as NSUInteger;
+    let right_size: NSUInteger = 1;
+
+    // Warm-up pass so pipeline-state creation and kernel load time are not
+    // included in the measurements.
+    autoreleasepool(|| {
+        let command_buffer = command_queue.new_command_buffer();
+        call_index_add(
+            device,
+            command_buffer,
+            kernels,
+            kernel_name,
+            &ids_buffer,
+            &input_buffer,
+            &mut output_buffer,
+            ids_dim_size,
+            left_size,
+            dst_dim_size,
+            right_size,
+        )
+        .unwrap();
+
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+    });
+
+    let total_time = autoreleasepool(|| {
+        let command_buffer = command_queue.new_command_buffer();
+        let start = Instant::now();
+        for _ in 0..iterations {
+            call_index_add(
+                device,
+                command_buffer,
+                kernels,
+                kernel_name,
+                &ids_buffer,
+                &input_buffer,
+                &mut output_buffer,
+                ids_dim_size,
+                left_size,
+                dst_dim_size,
+                right_size,
+            )
+            .unwrap();
+        }
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        start.elapsed()
+    });
+    println!(
+        "{0: <5} | {1: <11} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
+        type_name::<U>().split("::").last().unwrap(),
+        kernel_name,
+        input.len(),
+        iterations,
+        total_time,
+        total_time / iterations
+    );
+}
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 528c109d9d..8787c29371 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -12,19 +12,12 @@ void index_add(
     constant uint &dst_dim_size,
     constant uint &right_size,
 
-    uint threadgroup_size [[threads_per_threadgroup]],
-    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
-    uint thread_index [[thread_index_in_threadgroup]]
+    uint gid [[thread_position_in_grid]]
 ) {
+    if (gid >= left_size * right_size) return;
 
-    const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
-    if (gid >= left_size * right_size) {
-        return;
-    }
-
-    const uint i = gid;
-    const uint pre = i / right_size;
-    const uint post = i % right_size;
+    const uint pre = gid / right_size;
+    const uint post = gid % right_size;
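+
+    // Each thread owns a single (pre, post) position and walks the ids
+    // serially, so no two threads ever write the same output element.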
     for (uint j = 0; j < ids_dim_size; j++) {
         const uint idx = ids[j];
@@ -43,33 +36,32 @@ kernel void FN_NAME( \
     constant uint &left_size, \
     constant uint &dst_dim_size, \
     constant uint &right_size, \
-    uint threadgroup_size [[threads_per_threadgroup]], \
-    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
-    uint thread_index [[thread_index_in_threadgroup]] \
-) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
+    uint gid [[thread_position_in_grid]] \
+) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
 
+IA_OP(float, uint8_t, ia_u8_f32)
+IA_OP(float, uint32_t, ia_u32_f32)
+IA_OP(float, int64_t, ia_i64_f32)
 
-#if __METAL_VERSION__ >= 310
-IA_OP(bfloat, int64_t, ia_i64_bf16)
-IA_OP(bfloat, uint32_t, ia_u32_bf16)
-IA_OP(bfloat, uint8_t, ia_u8_bf16)
-#endif
-
-IA_OP(half, uint32_t, ia_u32_f16)
 IA_OP(half, uint8_t, ia_u8_f16)
+IA_OP(half, uint32_t, ia_u32_f16)
+IA_OP(half, int64_t, ia_i64_f16)
 
-IA_OP(float, int64_t, ia_i64_f32)
+IA_OP(uint8_t, uint8_t, ia_u8_u8)
+IA_OP(uint8_t, uint32_t, ia_u32_u8)
 IA_OP(uint8_t, int64_t, ia_i64_u8)
-IA_OP(int64_t, int64_t, ia_i64_i64)
-IA_OP(uint32_t, int64_t, ia_i64_u32)
-IA_OP(float, uint32_t, ia_u32_f32)
-IA_OP(uint8_t, uint32_t, ia_u32_u8)
-IA_OP(int64_t, uint32_t, ia_u32_i64)
+IA_OP(uint32_t, uint8_t, ia_u8_u32)
 IA_OP(uint32_t, uint32_t, ia_u32_u32)
+IA_OP(uint32_t, int64_t, ia_i64_u32)
 
-IA_OP(float, uint8_t, ia_u8_f32)
-IA_OP(uint8_t, uint8_t, ia_u8_u8)
-IA_OP(uint32_t, uint8_t, ia_u8_u32)
 IA_OP(int64_t, uint8_t, ia_u8_i64)
+IA_OP(int64_t, uint32_t, ia_u32_i64)
+IA_OP(int64_t, int64_t, ia_i64_i64)
+
+#if __METAL_VERSION__ >= 310
+IA_OP(bfloat, uint8_t, ia_u8_bf16)
+IA_OP(bfloat, uint32_t, ia_u32_bf16)
+IA_OP(bfloat, int64_t, ia_i64_bf16)
+#endif
\ No newline at end of file
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 29c5bac333..8b0ee35555 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -26,7 +26,7 @@ pub enum Source {
     Reduce,
 }
 
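+// Expands to one kernel-name constant per listed op (see the `contiguous`
+// module below).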
-macro_rules! ops{
+macro_rules! ops {
     ($($name:ident),+) => {
         pub mod contiguous {
             #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -716,6 +716,83 @@ pub fn call_where_cond_strided(
     Ok(())
 }
 
+pub fn call_index_add(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    ids: &Buffer,
+    input: &Buffer,
+    output: &mut Buffer,
+    ids_dim_size: NSUInteger,
+    left_size: NSUInteger,
+    dst_dim_size: NSUInteger,
+    right_size: NSUInteger,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Indexing, name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_buffer(0, Some(ids), 0);
+    encoder.set_buffer(1, Some(input), 0);
+    encoder.set_buffer(2, Some(output), 0);
+
+    encoder.set_bytes(
+        3,
+        core::mem::size_of_val(&ids_dim_size) as NSUInteger,
+        void_ptr(&ids_dim_size),
+    );
+    encoder.set_bytes(
+        4,
+        core::mem::size_of_val(&left_size) as NSUInteger,
+        void_ptr(&left_size),
+    );
+    encoder.set_bytes(
+        5,
+        core::mem::size_of_val(&dst_dim_size) as NSUInteger,
+        void_ptr(&dst_dim_size),
+    );
+    encoder.set_bytes(
+        6,
+        core::mem::size_of_val(&right_size) as NSUInteger,
+        void_ptr(&right_size),
+    );
+
+    // One thread per (pre, post) output position: launch enough 1-D
+    // threadgroups to cover `length`, then trim the group width so the total
+    // thread count stays as close to `length` as possible.
+    let length = left_size * right_size;
+    let threads = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length);
+    let thread_groups = (length + threads - 1) / threads;
+    let diff = (threads * thread_groups) - length;
+    let threads = threads - (diff / thread_groups);
+
+    let thread_group_count = MTLSize {
+        width: thread_groups,
+        height: 1,
+        depth: 1,
+    };
+    let threads_per_threadgroup = MTLSize {
+        width: threads,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, threads_per_threadgroup);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1058,19 +1135,27 @@ mod tests {
         encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
         encoder.set_bytes(6, 4, void_ptr(&right_size));
 
-        let grid_size = MTLSize {
-            width: right.len() as NSUInteger,
-            height: 1,
-            depth: 1,
-        };
+        let length = left_size * right_size;
+        let threads = std::cmp::min(
+            pipeline.max_total_threads_per_threadgroup(),
+            length as NSUInteger,
+        );
+        let thread_groups = (length as NSUInteger + threads - 1) / threads;
+        let diff = (threads * thread_groups) - length as NSUInteger;
+        let threads = threads - (diff / thread_groups);
 
-        let thread_group_size = MTLSize {
-            width: pipeline.max_total_threads_per_threadgroup(),
+        let thread_group_count = MTLSize {
+            width: thread_groups,
             height: 1,
             depth: 1,
         };
+        let threads_per_threadgroup = MTLSize {
+            width: threads,
+            height: 1,
+            depth: 1,
+        };
 
-        encoder.dispatch_thread_groups(grid_size, thread_group_size);
+        encoder.dispatch_thread_groups(thread_group_count, threads_per_threadgroup);
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();