Skip to content

Commit

Permalink
Add some metal sort kernels imported from MLX.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jan 26, 2025
1 parent 27996a1 commit 82c7998
Show file tree
Hide file tree
Showing 3 changed files with 943 additions and 0 deletions.
50 changes: 50 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const CONV: &str = include_str!("conv.metal");
// Metal shader sources, embedded into the binary at compile time with `include_str!`.
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
// Sort kernels imported from MLX.
const MLX_SORT: &str = include_str!("mlx_sort.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
const REDUCE: &str = include_str!("reduce.metal");
Expand All @@ -34,6 +35,7 @@ pub enum Source {
Fill,
Gemm,
Indexing,
MlxSort,
Quantized,
Random,
Reduce,
Expand Down Expand Up @@ -211,6 +213,7 @@ impl Kernels {
Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::MlxSort => MLX_SORT,
Source::Quantized => QUANTIZED,
Source::Random => RANDOM,
Source::Reduce => REDUCE,
Expand Down Expand Up @@ -2507,6 +2510,53 @@ pub fn call_arg_sort(
Ok(())
}

/// Encodes a dispatch of the MLX-derived sort kernel `name` onto the encoder
/// supplied by `ep`.
///
/// The pipeline is looked up in [`Source::MlxSort`]. The launch uses `nrows`
/// threadgroups (laid out along the grid's height), each `bn` threads wide.
///
/// # Arguments
/// * `device`, `kernels` - used to load (or fetch) the compute pipeline.
/// * `ep` - provides the compute command encoder the dispatch is recorded on.
/// * `name` - name of the kernel function to run.
/// * `bn` - threadgroup width, in threads.
/// * `nrows` - number of threadgroups launched.
/// * `ncols` - forwarded to the kernel as three of its `i32` arguments.
/// * `src` - input buffer (with offset), marked as read by the kernel.
/// * `dst` - output buffer, marked as written by the kernel.
///
/// # Errors
/// Returns the error from `Kernels::load_pipeline` if the pipeline cannot be
/// loaded; nothing is encoded in that case.
// The argument list mirrors the kernel launch parameters, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_arg_sort(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    bn: usize,
    nrows: usize,
    ncols: usize,
    src: BufferOffset,
    dst: &Buffer,
) -> Result<(), MetalKernelError> {
    // Resolve the pipeline first so a load failure returns before the encoder is touched.
    let pipeline = kernels.load_pipeline(device, Source::MlxSort, name)?;

    // Launch geometry: `nrows` groups stacked along the height, `bn` threads per group.
    let grid = MTLSize {
        width: 1,
        height: nrows as u64,
        depth: 1,
    };
    let block = MTLSize {
        width: bn as u64,
        height: 1,
        depth: 1,
    };

    // NOTE(review): `as i32` wraps silently when `ncols` exceeds `i32::MAX`.
    let row_len = ncols as i32;

    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    // Kernel arguments, in binding order. Presumably: input, output, length of the
    // sorted axis, in/out strides along the sorted axis (1 = contiguous), and in/out
    // strides between consecutive rows — TODO confirm against `mlx_sort.metal`.
    set_params!(
        encoder,
        (&src, dst, row_len, 1i32, 1i32, row_len, row_len)
    );

    encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(dst, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid, block);
    Ok(())
}

#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
BF16,
Expand Down
Loading

0 comments on commit 82c7998

Please sign in to comment.