Softmax support.
Narsil committed Nov 8, 2023
1 parent 44331b3 commit e1bab1b
Showing 9 changed files with 251 additions and 18 deletions.
12 changes: 12 additions & 0 deletions candle-core/src/dummy_metal_backend.rs
@@ -8,6 +8,18 @@ pub struct MetalDevice;
 #[derive(Debug)]
 pub struct MetalStorage;
 
+#[derive(thiserror::Error, Debug)]
+pub enum MetalError {
+    #[error("{0}")]
+    Message(String),
+}
+
+impl From<String> for MetalError {
+    fn from(e: String) -> Self {
+        MetalError::Message(e)
+    }
+}
+
 macro_rules! fail {
     () => {
         unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
4 changes: 2 additions & 2 deletions candle-core/src/error.rs
@@ -1,4 +1,4 @@
-use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};
+use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
 
 #[derive(Debug, Clone)]
 pub struct MatMulUnexpectedStriding {
@@ -163,7 +163,7 @@ pub enum Error {
     Cuda(Box<dyn std::error::Error + Send + Sync>),
 
     #[error("Metal error {0}")]
-    Metal(#[from] metal_backend::MetalError),
+    Metal(#[from] MetalError),
 
     #[error(transparent)]
     TryFromIntError(#[from] core::num::TryFromIntError),
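Taken together, the new From<String> impl in the dummy backend and the #[from] attribute above form a conversion chain from plain string messages up to candle's top-level Error. A minimal sketch of how a caller benefits (not part of the commit; the function is hypothetical):

    use candle_core::{Error, MetalError};

    // A String converts into MetalError via the new From<String> impl, and
    // MetalError converts into Error via the #[from] attribute, so errors
    // can propagate upward without manual wrapping.
    fn hypothetical_metal_op(ok: bool) -> Result<(), Error> {
        if !ok {
            let e: MetalError = "command buffer failed".to_string().into();
            return Err(Error::from(e)); // wrapped as Error::Metal(..)
        }
        Ok(())
    }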
7 changes: 3 additions & 4 deletions candle-core/src/lib.rs
@@ -49,13 +49,12 @@ mod device;
 pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
+mod dummy_metal_backend;
 pub mod error;
 mod indexer;
 pub mod layout;
 #[cfg(feature = "metal")]
 pub mod metal_backend;
-#[cfg(feature = "accelerate")]
-mod metal_backend;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
@@ -92,10 +91,10 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
 pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
 
 #[cfg(feature = "metal")]
-pub use metal_backend::{MetalDevice, MetalStorage};
+pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
 
 #[cfg(not(feature = "metal"))]
-pub use dummy_metal_backend::{MetalDevice, MetalStorage};
+pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
 
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
24 changes: 19 additions & 5 deletions candle-core/src/metal_backend.rs
@@ -9,7 +9,7 @@ use half::{bf16, f16};
 use metal;
 use metal::mps::matrix::encode_gemm;
 use metal::mps::Float32;
-use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
+use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
 use std::sync::Arc;
 use tracing::debug;
 
@@ -64,7 +64,19 @@ impl MetalDevice {
         self.registry_id()
     }
 
-    fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+    pub fn command_queue(&self) -> &CommandQueue {
+        &self.command_queue
+    }
+
+    pub fn kernels(&self) -> &Kernels {
+        &self.kernels
+    }
+
+    pub fn device(&self) -> &metal::Device {
+        &self.device
+    }
+
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
         let size = (element_count * dtype.size_in_bytes()) as u64;
         // debug!("Allocate 1 - buffer size {size}");
         self.device
@@ -242,13 +254,11 @@ impl BackendStorage for MetalStorage {
             );
         }
 
-        let start = std::time::Instant::now();
         command_buffer.commit();
         // command_buffer.wait_until_scheduled();
         debug!(
-            "cast {:?} - {:?} - {:?} - {:?}",
+            "cast {:?} - {:?} - {:?}",
             dtype,
-            start.elapsed(),
             self.buffer.length(),
             buffer.length()
         );
@@ -668,6 +678,10 @@ impl MetalStorage {
             _ => todo!("Unimplemented matmul for this pair"),
         }
     }
+
+    pub fn buffer(&self) -> &Buffer {
+        &self.buffer
+    }
 }
 
 impl BackendDevice for MetalDevice {
118 changes: 118 additions & 0 deletions candle-metal-kernels/src/lib.rs
@@ -491,6 +491,64 @@ pub fn call_reduce_contiguous(
     Ok(())
 }
 
+pub fn call_last_softmax(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    elements_to_sum: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_bytes(
+        1,
+        core::mem::size_of::<usize>() as u64,
+        void_ptr(&elements_to_sum),
+    );
+    encoder.set_buffer(2, Some(&input), 0);
+    encoder.set_buffer(3, Some(&output), 0);
+
+    let out_length = length / elements_to_sum;
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        // (elements_to_sum as u64 + 2 - 1) / 2,
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 pub fn void_ptr<T>(v: &T) -> *const c_void {
     (v as *const T).cast()
 }
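The dispatch geometry above assigns one threadgroup per output row and rounds the threadgroup width up to a power of two, which is what the kernel's tree reduction (for (uint s = blockDim / 2; ...)) requires. A worked instance for the third test case below (length = 6, elements_to_sum = 3), assuming the pipeline's maximum threadgroup size exceeds the row length:

    // Sketch of call_last_softmax's dispatch math for a 2x3 input.
    fn main() {
        let (length, elements_to_sum) = (6u64, 3u64);
        let out_length = length / elements_to_sum;       // 2 rows => 2 threadgroups
        let width = elements_to_sum.next_power_of_two(); // 3 -> 4 threads per group
        assert_eq!((out_length, width), (2, 4));
    }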
@@ -954,6 +1012,39 @@ mod tests {
         output.read_to_vec::<T>(out_length)
     }
 
+    fn run_softmax<T: Clone + std::fmt::Debug>(
+        v: &[T],
+        last_dim: usize,
+        name: &'static str,
+    ) -> Vec<T> {
+        let device = device();
+        let kernels = Kernels::new();
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let options = MTLResourceOptions::StorageModeManaged;
+        let input = device.new_buffer_with_data(
+            v.as_ptr() as *const core::ffi::c_void,
+            (v.len() * core::mem::size_of::<T>()) as u64,
+            options,
+        );
+        let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
+        call_last_softmax(
+            &device,
+            &command_buffer,
+            &kernels,
+            name,
+            v.len(),
+            last_dim,
+            &input,
+            &mut output,
+        )
+        .unwrap();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        output.read_to_vec::<T>(v.len())
+    }
+
     #[test]
     fn reduce_sum() {
         let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
@@ -971,4 +1062,31 @@
         let results = run_reduce(&v, out_length, "fast_sum_float");
         assert_eq!(approx(results, 4), vec![6.0, 15.0]);
     }
+
+    #[test]
+    fn softmax() {
+        let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+        let last_dim = 6;
+        let results = run_softmax(&v, last_dim, "softmax_float");
+        assert_eq!(
+            approx(results, 4),
+            vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
+        );
+
+        let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
+        let last_dim = 6;
+        let results = run_softmax(&v, last_dim, "softmax_float");
+        assert_eq!(
+            approx(results, 4),
+            vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
+        );
+
+        let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+        let last_dim = 3;
+        let results = run_softmax(&v, last_dim, "softmax_float");
+        assert_eq!(
+            approx(results, 4),
+            vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
+        );
+    }
 }
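The first two cases deliberately expect identical outputs: the kernel computes a max-shifted softmax, and shifting every input by a constant leaves the result unchanged. The third case runs two independent rows of length 3. In symbols:

    \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_k e^{x_k - m}},
    \quad m = \max_j x_j,
    \qquad \mathrm{softmax}(x + c) = \mathrm{softmax}(x) \ \text{for any constant } c.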
67 changes: 66 additions & 1 deletion candle-metal-kernels/src/reduce.metal
@@ -29,7 +29,7 @@ kernel void fast_sum_float(
     uint blockDim [[ threads_per_threadgroup ]]
 ) {
 
-    threadgroup int shared_memory[THREADGROUP_SIZE];
+    threadgroup float shared_memory[THREADGROUP_SIZE];
 
     shared_memory[tid] = 0;
     // Elements summed in this block range from dst_id * el_to_sum_per_block
@@ -57,3 +57,68 @@ kernel void fast_sum_float(
 
     dst[dst_id] = shared_memory[0];
 }
+
+kernel void softmax_float(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const float *src,
+    device float *dst,
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint blockDim [[ threads_per_threadgroup ]]
+) {
+
+    threadgroup float shared_memory[THREADGROUP_SIZE];
+
+    shared_memory[tid] = -INFINITY;
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        shared_memory[tid] = max(shared_memory[tid], src[idx]);
+        idx += blockDim;
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    // reduction in shared memory
+    for (uint s = blockDim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    float max = shared_memory[0];
+
+    shared_memory[tid] = 0;
+
+    // Restart
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+        const float val = exp(src[idx] - max);
+        dst[idx] = val;
+        shared_memory[tid] += val;
+        idx += blockDim;
+    }
+    // reduction in shared memory
+    for (uint s = blockDim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] += shared_memory[tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    const float inv_acc = 1/shared_memory[0];
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        dst[idx] *= inv_acc;
+        idx += blockDim;
+    }
+}
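The kernel makes three passes over each row: a tree-reduced maximum (subtracting it keeps exp from overflowing), an exponentiate-and-sum pass that also stores the exponentials in dst, and a final scale by the reciprocal of the sum. A scalar Rust cross-check of the same computation (a sketch, not part of the commit):

    // Reference softmax over one row, mirroring softmax_float's three passes.
    fn softmax_row(row: &mut [f32]) {
        // Pass 1: row maximum, for numerical stability.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // Pass 2: exponentiate shifted values and accumulate their sum.
        let mut sum = 0.0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        // Pass 3: scale by the reciprocal, as the kernel does with inv_acc.
        let inv_acc = 1.0 / sum;
        for v in row.iter_mut() {
            *v *= inv_acc;
        }
    }

    fn main() {
        let mut row = [1.0f32, 2.0, 3.0];
        softmax_row(&mut row);
        println!("{row:?}"); // ~ [0.0900, 0.2447, 0.6652], matching the tests above
    }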
3 changes: 2 additions & 1 deletion candle-nn/Cargo.toml
@@ -15,6 +15,7 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 half = { workspace = true }
 thiserror = { workspace = true }
 tracing = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
 rayon = { workspace = true }
@@ -29,5 +30,5 @@ clap = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
-metal = ["candle/metal"]
+metal = ["candle/metal", "candle-metal-kernels"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
5 changes: 3 additions & 2 deletions candle-nn/examples/cpu_benchmarks.rs
@@ -221,10 +221,11 @@ impl Benchmark for QMatMul {
     type PreProcessData = (candle::quantized::QMatMul, Tensor);
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
+        let device = Device::Cpu;
         let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
-        let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
+        let mm = candle::quantized::QTensor::new(zeros, (4096, 11008), &device)?;
         let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
-        let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
+        let arg = Tensor::randn(0f32, 1., (128, 11008), &device)?;
         Ok((mm, arg))
     }
 