Skip to content

Commit

Permalink
Merge pull request #7 from tracel-ai/fix/cuda/precision
Browse files Browse the repository at this point in the history
Fix cuda precision
  • Loading branch information
nathanielsimard authored Jul 18, 2024
2 parents 6466d74 + dfdb4de commit 46fb38b
Show file tree
Hide file tree
Showing 28 changed files with 360 additions and 131 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ This is just the beginning.
We plan to include more utilities such as convolutions, random number generation, fast Fourier transforms, and other essential algorithms.
We are a small team also building [Burn](https://burn.dev), so don't hesitate to contribute and port algorithms; it can help more than you would imagine!

## How it works

CubeCL leverages Rust's proc macro system in a unique two-step process:

1. Parsing: The proc macro parses the GPU kernel code using the `syn` crate.
2. Expansion: Instead of immediately generating an Intermediate Representation (IR), the macro generates a new Rust function.

The generated function, semantically similar to the original, is responsible for creating the IR when called.
This approach differs from traditional compilers, which typically generate IR directly after parsing.
Our method enables several key features:

- **Comptime**: By not transforming the original code, it becomes remarkably easy to integrate compile-time optimizations.
- **Automatic Vectorization**: By simply vectorizing the inputs of a CubeCL function, we can determine the vectorization factor of each intermediate variable during the expansion.
- **Rust Integration**: The generated code remains valid Rust code, allowing it to be bundled without any dependency on the specific runtime.

## Design

CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size.
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/compute/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ impl<R: Runtime> KernelLauncher<R> {
self,
cube_count: CubeCount<R::Server>,
kernel: K,
client: ComputeClient<R::Server, R::Channel>,
client: &ComputeClient<R::Server, R::Channel>,
) {
let bindings = self.into_bindings(&client);
let bindings = self.into_bindings(client);

let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));

Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Bool, Numeric, UInt, Vectorized, F32, F64, I32, I64};
use super::{Bool, CubePrimitive, Numeric, UInt, Vectorized, F32, F64, I32, I64};
use crate::{
ir::{ConstantScalarValue, Elem, Item, Operator, Variable, Vectorization},
prelude::{index_assign, init_expand, CubeContext, KernelBuilder, KernelLauncher},
Expand Down Expand Up @@ -200,10 +200,11 @@ impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
}
}

impl<T: CubeType> ExpandElementTyped<T> {
impl<T: CubePrimitive> ExpandElementTyped<T> {
/// Create an [ExpandElementTyped] from a value that is normaly a literal.
pub fn from_lit<L: Into<Variable>>(lit: L) -> Self {
let variable: Variable = lit.into();
let variable = T::as_elem().from_constant(variable);

ExpandElementTyped::new(ExpandElement::Plain(variable))
}
Expand Down
17 changes: 17 additions & 0 deletions crates/cubecl-core/src/ir/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ impl Elem {
ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
}
}
/// Get the size in bytes of one value of this element type.
///
/// Each variant maps to the size of a host-side Rust type: the `half` crate
/// types for 16-bit floats, the primitive floats/integers for the others,
/// `u32` for [Elem::UInt] and `bool` for [Elem::Bool].
pub fn size(&self) -> usize {
    match self {
        // Floating-point kinds.
        Elem::Float(FloatKind::F16) => core::mem::size_of::<half::f16>(),
        Elem::Float(FloatKind::BF16) => core::mem::size_of::<half::bf16>(),
        Elem::Float(FloatKind::F32) => core::mem::size_of::<f32>(),
        Elem::Float(FloatKind::F64) => core::mem::size_of::<f64>(),
        // Signed integer kinds.
        Elem::Int(IntKind::I32) => core::mem::size_of::<i32>(),
        Elem::Int(IntKind::I64) => core::mem::size_of::<i64>(),
        // Unsigned integers are sized as `u32`.
        Elem::UInt => core::mem::size_of::<u32>(),
        Elem::Bool => core::mem::size_of::<bool>(),
    }
}
}

impl From<Elem> for Item {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub use runtime::*;
pub use cubecl_macros::cube;
pub use cubecl_macros::CubeLaunch;
pub use cubecl_macros::CubeType;
pub use cubecl_runtime::benchmark;

/// An approximation of the subcube dimension.
pub const SUBCUBE_DIM_APPROX: usize = 16;
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let out = client.empty(core::mem::size_of::<f32>() * 256);

kernel_simple_1::launch::<R>(
client.clone(),
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(16, 16, 1),
ArrayArg::new(&lhs, 256),
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/runtime_tests/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

kernel_with_generics::launch::<F32, R>(
client.clone(),
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
ArrayArg::new(&handle, 2),
Expand All @@ -36,7 +36,7 @@ pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server,
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

kernel_without_generics::launch::<R>(
client.clone(),
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
ArrayArg::new(&handle, 2),
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/src/runtime_tests/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn test_slice_select<R: Runtime>(client: ComputeClient<R::Server, R::Channel
let output = client.empty(core::mem::size_of::<f32>());

slice_select::launch::<F32, R>(
client.clone(),
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(1, 1, 1),
ArrayArg::new(&input, 5),
Expand All @@ -49,7 +49,7 @@ pub fn test_slice_len<R: Runtime>(client: ComputeClient<R::Server, R::Channel>)
let output = client.empty(core::mem::size_of::<u32>());

slice_len::launch::<F32, R>(
client.clone(),
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(1, 1, 1),
ArrayArg::new(&input, 5),
Expand All @@ -67,7 +67,7 @@ pub fn test_slice_assign<R: Runtime>(client: ComputeClient<R::Server, R::Channel
let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));

slice_assign::launch::<F32, R>(
client.clone(),
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(1, 1, 1),
ArrayArg::new(&input, 5),
Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-core/src/runtime_tests/subcube.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub fn test_subcube_sum<TestRuntime: Runtime>(
&[17.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_count, cube_dim, handle| {
kernel_sum::launch::<F32, TestRuntime>(client.clone(), cube_count, cube_dim, handle)
kernel_sum::launch::<F32, TestRuntime>(&client, cube_count, cube_dim, handle)
},
);
}
Expand All @@ -63,7 +63,7 @@ pub fn test_subcube_prod<TestRuntime: Runtime>(
&[140.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_dim, settings, handle| {
kernel_prod::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
kernel_prod::launch::<F32, TestRuntime>(&client, cube_dim, settings, handle)
},
);
}
Expand All @@ -75,7 +75,7 @@ pub fn test_subcube_max<TestRuntime: Runtime>(
&[7.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_dim, settings, handle| {
kernel_max::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
kernel_max::launch::<F32, TestRuntime>(&client, cube_dim, settings, handle)
},
);
}
Expand All @@ -88,7 +88,7 @@ pub fn test_subcube_min<TestRuntime: Runtime>(
&[1.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_dim, settings, handle| {
kernel_min::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
kernel_min::launch::<F32, TestRuntime>(&client, cube_dim, settings, handle)
},
);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/tests/frontend/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ mod tests {
let y = scope.create_local(Item::new(UInt::as_elem()));
let z = scope.create_local(Item::new(UInt::as_elem()));

cpa!(&mut scope, x = shape(input, 1));
cpa!(&mut scope, y = stride(input, 1));
cpa!(&mut scope, x = shape(input, 1u32));
cpa!(&mut scope, y = stride(input, 1u32));
cpa!(&mut scope, z = len(input));

scope.operations
Expand Down
Loading

0 comments on commit 46fb38b

Please sign in to comment.