Metal backend: implement index_add, with better performance and a benchmark
ivarflakstad committed Nov 10, 2023
1 parent c5b77d3 commit fc53b41
Showing 4 changed files with 408 additions and 48 deletions.
79 changes: 71 additions & 8 deletions candle-core/src/metal_backend.rs
@@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use crate::{bail, CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use core::mem;
@@ -542,14 +542,77 @@ impl BackendStorage for MetalStorage {

fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
todo!()
// TODO: Definitely refactor this
let kernel_name = match ids.dtype {
DType::U8 => match src.dtype {
DType::U8 => "ia_u8_u8",
DType::U32 => "ia_u8_u32",
DType::I64 => "ia_u8_i64",
DType::BF16 => "ia_u8_bf16",
DType::F16 => "ia_u8_f16",
DType::F32 => "ia_u8_f32",
_ => bail!("Unsupported dtype for index add"),
},
DType::U32 => match src.dtype {
DType::U8 => "ia_u32_u8",
DType::U32 => "ia_u32_u32",
DType::I64 => "ia_u32_i64",
DType::BF16 => "ia_u32_bf16",
DType::F16 => "ia_u32_f16",
DType::F32 => "ia_u32_f32",
_ => bail!("Unsupported dtype for index add"),
},
DType::I64 => match src.dtype {
DType::U8 => "ia_i64_u8",
DType::U32 => "ia_i64_u32",
DType::I64 => "ia_i64_i64",
DType::BF16 => "ia_i64_bf16",
DType::F16 => "ia_i64_f16",
DType::F32 => "ia_i64_f32",
_ => bail!("Unsupported dtype for index add"),
},
_ => bail!("Unsupported index dtype for index add"),
};

let device = self.device.clone();
let mut dst = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut dst, 0, l)?;

let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let dst_dim_size = l.dims()[dim];
let ids_dim_size = ids_l.dims()[0];

let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_index_add(
device.device(),
&command_buffer,
&device.kernels,
&kernel_name,
&ids.buffer,
&src.buffer,
&mut dst.buffer,
ids_dim_size as NSUInteger,
left_size as NSUInteger,
dst_dim_size as NSUInteger,
right_size as NSUInteger,
)
.unwrap();

command_buffer.commit();

return Ok(Self {
buffer: dst.buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}

fn matmul(
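For reference, the `left_size`/`right_size` bookkeeping above collapses the destination into a contiguous `(left_size, dst_dim_size, right_size)` view and the source into `(left_size, ids_dim_size, right_size)`, with `dim` as the middle axis. A minimal CPU sketch of the semantics the Metal kernel is expected to implement (illustrative only; `f32` data and `u32` ids are example dtypes, and contiguous layouts are assumed):

```rust
/// CPU reference for index_add: dst[pre, ids[j], post] += src[pre, j, post].
fn index_add_ref(
    ids: &[u32],
    src: &[f32],
    dst: &mut [f32],
    left_size: usize,
    dst_dim_size: usize,
    right_size: usize,
) {
    let ids_dim_size = ids.len();
    for pre in 0..left_size {
        for (j, &idx) in ids.iter().enumerate() {
            let idx = idx as usize;
            assert!(idx < dst_dim_size, "index out of range along `dim`");
            for post in 0..right_size {
                // Source is indexed by position j along `dim`, destination by ids[j].
                let src_i = (pre * ids_dim_size + j) * right_size + post;
                let dst_i = (pre * dst_dim_size + idx) * right_size + post;
                dst[dst_i] += src[src_i];
            }
        }
    }
}
```

With `left_size = 1` and `right_size = 1` this reduces to `dst[ids[j]] += src[j]`, which matches the per-`gid` loop in the Metal kernel shown further down.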
220 changes: 220 additions & 0 deletions candle-metal-kernels/examples/indexing.rs
@@ -0,0 +1,220 @@
use candle_metal_kernels::{call_index_add, Kernels};
use half::{bf16, f16};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions, NSUInteger};
use rand;
use std::any::type_name;
use std::time::Instant;

macro_rules! ia_name {
($idx_type:literal, $dtype:literal) => {
concat!("ia_", $idx_type, "_", $dtype)
};
}

macro_rules! ia_kernels {
($dtype:literal) => {
[
ia_name!("u8", $dtype),
ia_name!("u32", $dtype),
ia_name!("i64", $dtype),
]
};
}
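For example, `ia_kernels!("f32")` expands to `["ia_u8_f32", "ia_u32_f32", "ia_i64_f32"]`: one kernel name per index dtype, matching the names selected in `metal_backend.rs` above.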

struct IdsData {
u8: Vec<u8>,
u32: Vec<u32>,
i64: Vec<i64>,
}

fn main() {
let device = Device::system_default().unwrap();

let kernels = Kernels::new();
let u8_1k = (0..1000).map(|_| rand::random::<u8>()).collect::<Vec<_>>();
let u8_10k = (0..10000).map(|_| rand::random::<u8>()).collect::<Vec<_>>();

let u8_100k = (0..100000)
.map(|_| rand::random::<u8>())
.collect::<Vec<_>>();
let f32_map = |v: &[u8]| v.iter().map(|v| *v as f32).collect::<Vec<_>>();
let f32_1k = f32_map(&u8_1k);
let f32_10k = f32_map(&u8_10k);

let f32_100k = f32_map(&u8_100k);
let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
let f16_1k = f16_map(&f32_1k);
let f16_10k = f16_map(&f32_10k);

let f16_100k = f16_map(&f32_100k);
let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
let bf16_1k = bf16_map(&f32_1k);
let bf16_10k = bf16_map(&f32_10k);

let bf16_100k = bf16_map(&f32_100k);
let u32_map = |v: &[u8]| v.iter().map(|v| *v as u32).collect::<Vec<_>>();
let u32_1k = u32_map(&u8_1k);
let u32_10k = u32_map(&u8_10k);

let u32_100k = u32_map(&u8_100k);
let i64_map = |v: &[u8]| v.iter().map(|v| *v as i64).collect::<Vec<_>>();
let i64_1k = i64_map(&u8_1k);
let i64_10k = i64_map(&u8_10k);

let i64_100k = i64_map(&u8_100k);

let f32_kernels = ia_kernels!("f32");
let f16_kernels = ia_kernels!("f16");
let bf16_kernels = ia_kernels!("bf16");
let u8_kernels = ia_kernels!("u8");
let u32_kernels = ia_kernels!("u32");
let i64_kernels = ia_kernels!("i64");

println!(
"{0: <5} | {1: <11} | {2: <6} | {3: <5} | {4: <12} | {5}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);

let ids_data_1k = IdsData {
u8: u8_1k.clone(),
u32: u32_1k.clone(),
i64: i64_1k.clone(),
};
let ids_data_10k = IdsData {
u8: u8_10k.clone(),
u32: u32_10k.clone(),
i64: i64_10k.clone(),
};
let ids_data_100k = IdsData {
u8: u8_100k.clone(),
u32: u32_100k.clone(),
i64: i64_100k.clone(),
};

// f32
run_indexing_benches(&device, &kernels, &ids_data_1k, &f32_1k, f32_kernels);
run_indexing_benches(&device, &kernels, &ids_data_10k, &f32_10k, f32_kernels);
run_indexing_benches(&device, &kernels, &ids_data_100k, &f32_100k, f32_kernels);

// f16
run_indexing_benches(&device, &kernels, &ids_data_1k, &f16_1k, f16_kernels);
run_indexing_benches(&device, &kernels, &ids_data_10k, &f16_10k, f16_kernels);
run_indexing_benches(&device, &kernels, &ids_data_100k, &f16_100k, f16_kernels);

// bf16
run_indexing_benches(&device, &kernels, &ids_data_1k, &bf16_1k, bf16_kernels);
run_indexing_benches(&device, &kernels, &ids_data_10k, &bf16_10k, bf16_kernels);
run_indexing_benches(&device, &kernels, &ids_data_100k, &bf16_100k, bf16_kernels);

// u8
run_indexing_benches(&device, &kernels, &ids_data_1k, &u8_1k, u8_kernels);
run_indexing_benches(&device, &kernels, &ids_data_10k, &u8_10k, u8_kernels);
run_indexing_benches(&device, &kernels, &ids_data_100k, &u8_100k, u8_kernels);

// u32
run_indexing_benches(&device, &kernels, &ids_data_1k, &u32_1k, u32_kernels);
run_indexing_benches(&device, &kernels, &ids_data_10k, &u32_10k, u32_kernels);
run_indexing_benches(&device, &kernels, &ids_data_100k, &u32_100k, u32_kernels);

// i64
run_indexing_benches(&device, &kernels, &ids_data_1k, &i64_1k, i64_kernels);
run_indexing_benches(&device, &kernels, &ids_data_10k, &i64_10k, i64_kernels);
run_indexing_benches(&device, &kernels, &ids_data_100k, &i64_100k, i64_kernels);
}

fn run_indexing_benches<T: Clone>(
device: &Device,
kernels: &Kernels,
ids: &IdsData,
input: &[T],
index_add_kernels: [&'static str; 3],
) {
run_indexing_bench(device, kernels, &ids.u8, input, index_add_kernels[0]);
run_indexing_bench(device, kernels, &ids.u32, input, index_add_kernels[1]);
run_indexing_bench(device, kernels, &ids.i64, input, index_add_kernels[2]);
}

fn run_indexing_bench<T: Clone, U: Clone>(
device: &Device,
kernels: &Kernels,
ids: &[T],
input: &[U],
kernel_name: &'static str,
) {
let command_queue = device.new_command_queue();
let options = MTLResourceOptions::StorageModeManaged;

let iterations = 10000;
let ids_buffer = device.new_buffer_with_data(
ids.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(ids) as u64,
options,
);
let input_buffer = device.new_buffer_with_data(
input.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(input) as u64,
options,
);
let mut output_buffer = device.new_buffer(core::mem::size_of_val(input) as u64, options);

let ids_dim_size = ids.len() as NSUInteger;
let left_size = input.len() as NSUInteger;
let dst_dim_size = ids.len() as NSUInteger;
let right_size = input.len() as NSUInteger;

// Warm-up pass so kernel/pipeline load time is not included in the benchmark
autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
call_index_add(
device,
&command_buffer,
kernels,
kernel_name,
&ids_buffer,
&input_buffer,
&mut output_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size,
)
.unwrap();

command_buffer.commit();
command_buffer.wait_until_completed();
});

let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
for _ in 0..iterations {
call_index_add(
device,
&command_buffer,
kernels,
kernel_name,
&ids_buffer,
&input_buffer,
&mut output_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();

start.elapsed()
});
println!(
"{0: <5} | {1: <11} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
type_name::<U>().split("::").last().unwrap(),
kernel_name.to_string(),
input.len(),
iterations,
total_time,
total_time / iterations
);
}
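Assuming the crate keeps the standard Cargo layout shown above, the benchmark can presumably be run with `cargo run --release -p candle-metal-kernels --example indexing` on an Apple-silicon machine. Each printed row then reports the total and average per-call time for one (data dtype, kernel, input size) combination over the 10,000 encoded calls.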
54 changes: 23 additions & 31 deletions candle-metal-kernels/src/indexing.metal
@@ -12,19 +12,12 @@ void index_add(
constant uint &dst_dim_size,
constant uint &right_size,

uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
uint gid [[thread_position_in_grid]]
) {
if (gid >= left_size * right_size) return;

const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}

const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
const uint pre = gid / right_size;
const uint post = gid % right_size;

for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
@@ -43,33 +36,32 @@ kernel void FN_NAME( \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
uint gid [[thread_position_in_grid]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \


IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(float, int64_t, ia_i64_f32)

#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif

IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, int64_t, ia_i64_f16)

IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)

IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(uint32_t, int64_t, ia_i64_u32)

IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(int64_t, int64_t, ia_i64_i64)

#if __METAL_VERSION__ >= 310
IA_OP(bfloat, uint8_t, ia_u8_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, int64_t, ia_i64_bf16)
#endif
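The kernel change above drops the manual reconstruction of a global index from threadgroup coordinates and instead reads a single `thread_position_in_grid` value, so correctness now rests on the host dispatching at least `left_size * right_size` threads. As a rough illustration of that launch contract (a sketch, not the actual `call_index_add` implementation; `encoder` and `pipeline` are assumed to be created elsewhere with the metal crate):

```rust
use metal::{ComputeCommandEncoderRef, ComputePipelineState, MTLSize};

/// Illustrative helper (not part of candle-metal-kernels): cover every
/// `thread_position_in_grid` value in [0, left_size * right_size) with a 1-D grid.
fn dispatch_index_add_grid(
    encoder: &ComputeCommandEncoderRef,
    pipeline: &ComputePipelineState,
    left_size: u64,
    right_size: u64,
) {
    // One thread per (pre, post) pair; the kernel's early-return guard handles any excess.
    let thread_count = (left_size * right_size).max(1);
    // Cap the threadgroup width at what this pipeline supports.
    let width = pipeline.max_total_threads_per_threadgroup().min(thread_count);
    encoder.dispatch_threads(MTLSize::new(thread_count, 1, 1), MTLSize::new(width, 1, 1));
}
```

`dispatch_threads` relies on non-uniform threadgroup sizes, which Apple GPUs support; a `dispatch_thread_groups` call with a rounded-up group count would be the conservative alternative, and the `gid >= left_size * right_size` guard in the kernel already makes that safe.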