Skip to content

Commit

Permalink
simplify threads/threadgroups calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Nov 10, 2023
1 parent b87d49e commit c5b77d3
Show file tree
Hide file tree
Showing 13 changed files with 461 additions and 134 deletions.
10 changes: 1 addition & 9 deletions candle-metal-kernels/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,4 @@ tracing = "0.1.37"

[dev-dependencies]
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
rand = "0.8.5"

[[example]]
name = "unary"
path = "examples/unary.rs"

[[example]]
name = "binary"
path = "examples/binary.rs"
rand = "0.8.5"
23 changes: 21 additions & 2 deletions candle-metal-kernels/examples/affine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn main() {
.collect::<Vec<_>>();

println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5:}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);

Expand All @@ -42,6 +42,25 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {

let mul: f32 = 1.2345;
let add: f32 = 2.3456;

// Ghost pass to ensure kernel load time is not included in benchmarks
autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
call_affine(
&device,
command_buffer,
&kernels,
v.len(),
&input,
&mut output,
mul,
add,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
});

let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
Expand All @@ -64,7 +83,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
type_name::<T>().split("::").last().unwrap(),
"affine",
v.len(),
Expand Down
48 changes: 45 additions & 3 deletions candle-metal-kernels/examples/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fn main() {
];

println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);

Expand Down Expand Up @@ -105,6 +105,25 @@ fn run_binary_bench<T: Clone>(
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

// Contiguous
// Ghost pass to ensure kernel load time is not included in benchmarks
for kernel_name in contiguous {
autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
call_binary_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
});
}
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
Expand All @@ -128,7 +147,7 @@ fn run_binary_bench<T: Clone>(
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
Expand All @@ -142,6 +161,29 @@ fn run_binary_bench<T: Clone>(
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
// Ghost pass to ensure kernel load time is not included in benchmarks
for kernel_name in strided {
autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
call_binary_strided(
device,
command_buffer,
&kernels,
kernel_name,
&shape,
&input,
&strides,
offset,
&input,
&strides,
offset,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
});
}
for kernel_name in strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
Expand Down Expand Up @@ -170,7 +212,7 @@ fn run_binary_bench<T: Clone>(
});

println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
Expand Down
22 changes: 20 additions & 2 deletions candle-metal-kernels/examples/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn main() {
let contiguous_kernels = ["cast_u32_f32"];

println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5:}",
"dtype", "kernel", "size", "runs", "total time", "avg time"
);

Expand Down Expand Up @@ -48,6 +48,24 @@ fn run_cast_bench<T: Clone>(
let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

// Contiguous
// Ghost pass to ensure kernel load time is not included in benchmarks
for kernel_name in contiguous {
autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
call_cast_contiguous(
device,
&command_buffer,
kernels,
kernel_name,
v.len(),
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
});
}
for kernel_name in contiguous {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
Expand All @@ -70,7 +88,7 @@ fn run_cast_bench<T: Clone>(
start.elapsed()
});
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
v.len(),
Expand Down
102 changes: 102 additions & 0 deletions candle-metal-kernels/examples/reduce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use candle_metal_kernels::{call_reduce_contiguous, Kernels};
use metal::objc::rc::autoreleasepool;
use metal::{Device, MTLResourceOptions};
use rand;
use std::any::type_name;
use std::time::Instant;

/// Benchmark entry point: times the contiguous reduce kernels
/// ("fast_sum_float", "softmax_float") on random f32 inputs of
/// 1k, 10k and 100k elements and prints one table row per run.
fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    // Generate all random inputs up front, one Vec per size.
    let sizes = [1000usize, 10000, 100000];
    let inputs: Vec<Vec<f32>> = sizes
        .iter()
        .map(|&n| (0..n).map(|_| rand::random::<f32>()).collect())
        .collect();

    let reduce_kernels = ["fast_sum_float", "softmax_float"];

    // Table header; column widths match the rows printed by run_reduce_bench.
    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12} | {5}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32 benchmarks, smallest input first.
    for input in &inputs {
        run_reduce_bench(&device, &kernels, input, reduce_kernels);
    }
}

/// Benchmarks each contiguous reduce kernel in `reduce_kernels` over `v`.
///
/// Each kernel is first dispatched once as a warm-up ("ghost pass") so that
/// pipeline/kernel load cost is excluded from the measurement; the timed run
/// then encodes `iterations` dispatches into a single command buffer and
/// prints the total and per-iteration wall time.
///
/// Fix vs. original: `device` and `kernels` are already references, so the
/// calls pass them directly instead of `&device`/`&kernels` (which created
/// `&&T` arguments that only compiled via deref coercion and was
/// inconsistent with the other benchmark examples in this crate).
fn run_reduce_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    reduce_kernels: [&'static str; 2],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    // Upload the host slice into a managed-storage buffer.
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        core::mem::size_of_val(v) as u64,
        options,
    );
    // NOTE(review): output is sized like the input; a reduce presumably needs
    // far less, but this mirrors the other examples — confirm before shrinking.
    let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options);

    // Contiguous
    // Ghost pass to ensure kernel load time is not included in benchmarks.
    for kernel_name in reduce_kernels {
        autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            call_reduce_contiguous(
                device,
                command_buffer,
                kernels,
                kernel_name,
                v.len(),
                v.len(),
                &input,
                &mut output,
            )
            .unwrap();
            command_buffer.commit();
            command_buffer.wait_until_completed();
        });
    }
    for kernel_name in reduce_kernels {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            // All iterations are encoded into one command buffer and timed
            // together, including the final GPU completion wait.
            for _ in 0..iterations {
                call_reduce_contiguous(
                    device,
                    command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    v.len(),
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        // One table row per kernel; widths match the header printed in main().
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <12?} | {5:?}",
            type_name::<T>().split("::").last().unwrap(),
            kernel_name.to_string(),
            v.len(),
            iterations,
            total_time,
            total_time / iterations
        );
    }
}
Loading

0 comments on commit c5b77d3

Please sign in to comment.