Skip to content

Commit

Permalink
Merge pull request #36 from tracel-ai/perf/better-vectorization-cuda
Browse files Browse the repository at this point in the history
Perf/better vectorization on cuda
  • Loading branch information
nathanielsimard authored Jul 24, 2024
2 parents 271a527 + b454fc5 commit df8ef81
Show file tree
Hide file tree
Showing 15 changed files with 539 additions and 379 deletions.
11 changes: 10 additions & 1 deletion crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::marker::PhantomData;
use std::{fmt::Debug, marker::PhantomData};

use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
use alloc::sync::Arc;
Expand Down Expand Up @@ -78,6 +78,15 @@ pub enum CubeCount<S: ComputeServer> {
Dynamic(Binding<S>),
}

impl<S: ComputeServer> Debug for CubeCount<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
CubeCount::Dynamic(_) => f.write_str("binding"),
}
}
}

impl<S: ComputeServer> Clone for CubeCount<S> {
fn clone(&self) -> Self {
match self {
Expand Down
44 changes: 44 additions & 0 deletions crates/cubecl-core/src/runtime_tests/assign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use crate as cubecl;

use cubecl::prelude::*;

#[cube(launch)]
pub fn kernel_assign(output: &mut Array<F32>, vectorization: Comptime<UInt>) {
if UNIT_POS == UInt::new(0) {
let item = F32::vectorized(5.0, Comptime::get(vectorization));
output[0] = item;
}
}

pub fn test_kernel_assign_scalar<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

let vectorization = 2;

kernel_assign::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
ArrayArg::vectorized(vectorization, &handle, 2),
UInt::new(vectorization as u32),
);

let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 5.0);
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_assign {
() => {
use super::*;

#[test]
fn test_assign_scalar() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::assign::test_kernel_assign_scalar::<TestRuntime>(client);
}
};
}
4 changes: 4 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub mod assign;
pub mod cmma;
pub mod launch;
pub mod slice;
pub mod subcube;
pub mod topology;

#[allow(missing_docs)]
#[macro_export]
Expand All @@ -13,5 +15,7 @@ macro_rules! testgen_all {
cubecl_core::testgen_launch!();
cubecl_core::testgen_cmma!();
cubecl_core::testgen_slice!();
cubecl_core::testgen_assign!();
cubecl_core::testgen_topology!();
};
}
57 changes: 57 additions & 0 deletions crates/cubecl-core/src/runtime_tests/topology.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate as cubecl;

use cubecl::prelude::*;

#[cube(launch)]
pub fn kernel_absolute_pos(output1: &mut Array<UInt>, output2: &mut Array<UInt>) {
if ABSOLUTE_POS >= output1.len() {
return;
}

output1[ABSOLUTE_POS] = ABSOLUTE_POS;
output2[ABSOLUTE_POS] = ABSOLUTE_POS;
}

pub fn test_kernel_topology_absolute_pos<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let cube_count = (3, 5, 7);
let cube_dim = (16, 16, 1);
let extra: u32 = 3u32;

let length =
(cube_count.0 * cube_count.1 * cube_count.2 * cube_dim.0 * cube_dim.1 * cube_dim.2) + extra;
let handle1 = client.empty(length as usize * core::mem::size_of::<u32>());
let handle2 = client.empty(length as usize * core::mem::size_of::<u32>());

kernel_absolute_pos::launch::<R>(
&client,
CubeCount::Static(cube_count.0, cube_count.1, cube_count.2),
CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2),
ArrayArg::new(&handle1, length as usize),
ArrayArg::new(&handle2, length as usize),
);

let actual = client.read(handle1.binding());
let actual = u32::from_bytes(&actual);
let mut expect: Vec<u32> = (0..length - extra).collect();
expect.push(0);
expect.push(0);
expect.push(0);

assert_eq!(actual, &expect);
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_topology {
() => {
use super::*;

#[test]
fn test_topology_scalar() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::topology::test_kernel_topology_absolute_pos::<TestRuntime>(
client,
);
}
};
}
15 changes: 8 additions & 7 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashSet;

use cubecl_core::{
ir::{self as gpu, ConstantScalarValue},
Compiler,
Expand All @@ -22,6 +24,7 @@ pub struct CudaCompiler {
stride: bool,
num_inputs: usize,
num_outputs: usize,
items: HashSet<super::Item>,
}

impl Compiler for CudaCompiler {
Expand Down Expand Up @@ -86,6 +89,7 @@ impl CudaCompiler {
wmma_activated: self.wmma,
bf16: self.bf16,
f16: self.f16,
items: self.items,
}
}

Expand Down Expand Up @@ -548,13 +552,10 @@ impl CudaCompiler {
}

fn compile_item(&mut self, item: gpu::Item) -> super::Item {
match item.vectorization {
4 => super::Item::Vec4(self.compile_elem(item.elem)),
3 => super::Item::Vec3(self.compile_elem(item.elem)),
2 => super::Item::Vec2(self.compile_elem(item.elem)),
1 => super::Item::Scalar(self.compile_elem(item.elem)),
_ => panic!("Vectorization factor unsupported {:?}", item.vectorization),
}
let item = super::Item::new(self.compile_elem(item.elem), item.vectorization.into());
self.items.insert(item);
self.items.insert(item.optimized());
item
}

fn compile_elem(&mut self, value: gpu::Elem) -> super::Elem {
Expand Down
Loading

0 comments on commit df8ef81

Please sign in to comment.