Skip to content

Commit

Permalink
Unchecked Execution Mode (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Aug 8, 2024
1 parent dd274b6 commit 30d090c
Show file tree
Hide file tree
Showing 49 changed files with 1,271 additions and 849 deletions.
24 changes: 13 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Simply annotate functions with the `cube` attribute to indicate that they should
```rust
use cubecl::prelude::*;

#[cube(launch)]
#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
if ABSOLUTE_POS < input.len() {
output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
Expand All @@ -38,24 +38,26 @@ fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
fn gelu_scalar<F: Float>(x: F) -> F {
x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0
}

```

You can then launch the kernel using the autogenerated `gelu_array::launch` function.
You can then launch the kernel using the autogenerated `gelu_array::launch_unchecked` function.

```rust
fn launch<R: Runtime>(device: &R::Device) {
let client = R::client(device);
let input = &[-1., 0., 1., 5.];
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());

gelu_array::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::new(input.len() as u32, 1, 1),
ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()),
ArrayArg::new(&output_handle, input.len()),
);
let input_handle = client.create(f32::as_bytes(input));

unsafe {
gelu_array::launch_unchecked::<F32, R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(input.len() as u32, 1, 1),
ArrayArg::from_raw_parts(&input_handle, input.len(), 1),
ArrayArg::from_raw_parts(&output_handle, input.len(), 1),
)
};

let bytes = client.read(output_handle.binding());
let output = f32::from_bytes(&bytes);
Expand Down
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::ir::{Elem, KernelDefinition};
use cubecl_runtime::ExecutionMode;
use std::fmt::Display;

/// Trait for compiled code representation
Expand All @@ -13,7 +14,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
type Representation: CompilerRepresentation;

/// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
fn compile(kernel: KernelDefinition) -> Self::Representation;
fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
/// The maximal size of a shared memory
Expand Down
19 changes: 11 additions & 8 deletions crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use std::{

use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId};
use alloc::sync::Arc;
use cubecl_runtime::server::{Binding, ComputeServer};
use cubecl_runtime::{
server::{Binding, ComputeServer},
ExecutionMode,
};

/// A kernel, compiled in the target language
pub struct CompiledKernel {
Expand Down Expand Up @@ -157,7 +160,7 @@ pub trait CubeTask: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> KernelId;
/// Compile the kernel into source
fn compile(&self) -> CompiledKernel;
fn compile(&self, mode: ExecutionMode) -> CompiledKernel;
}

/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
Expand All @@ -168,10 +171,10 @@ pub struct KernelTask<C: Compiler, K: Kernel> {
}

impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
fn compile(&self) -> CompiledKernel {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel {
let gpu_ir = self.kernel_definition.define();
let cube_dim = gpu_ir.cube_dim;
let lower_level_ir = C::compile(gpu_ir);
let lower_level_ir = C::compile(gpu_ir, mode);
let shared_mem_bytes = lower_level_ir.shared_memory_size();
let source = lower_level_ir.to_string();

Expand All @@ -190,8 +193,8 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
}

impl CubeTask for Arc<dyn CubeTask> {
fn compile(&self) -> CompiledKernel {
self.as_ref().compile()
fn compile(&self, mode: ExecutionMode) -> CompiledKernel {
self.as_ref().compile(mode)
}

fn id(&self) -> KernelId {
Expand All @@ -200,8 +203,8 @@ impl CubeTask for Arc<dyn CubeTask> {
}

impl CubeTask for Box<dyn CubeTask> {
fn compile(&self) -> CompiledKernel {
self.as_ref().compile()
fn compile(&self, mode: ExecutionMode) -> CompiledKernel {
self.as_ref().compile(mode)
}

fn id(&self) -> KernelId {
Expand Down
61 changes: 36 additions & 25 deletions crates/cubecl-core/src/compute/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,24 @@ impl<R: Runtime> KernelLauncher<R> {
client.execute(kernel, cube_count, bindings);
}

/// Launch the kernel without bounds checking.
///
/// The launcher is consumed: its registered arguments are turned into bindings, the kernel is
/// wrapped in a boxed [KernelTask], and both are handed to the client's unchecked execution
/// entry point together with the cube count.
///
/// # Safety
///
/// Out-of-bounds reads and writes can happen.
pub unsafe fn launch_unchecked<K: Kernel>(
    self,
    cube_count: CubeCount<R::Server>,
    kernel: K,
    client: &ComputeClient<R::Server, R::Channel>,
) {
    // Bindings are built first, exactly as in the checked launch path.
    let launch_bindings = self.into_bindings(client);
    let task = Box::new(KernelTask::<R::Compiler, K>::new(kernel));

    client.execute_unchecked(task, cube_count, launch_bindings);
}

/// We need to create the bindings in the same order they are defined in the compilation step.
///
/// The function [crate::KernelIntegrator::integrate] starts by registering the input tensors followed
Expand Down Expand Up @@ -174,21 +192,16 @@ impl<R: Runtime> TensorState<R> {

bindings.push(tensor.handle.clone().binding());

let old_rank = if metadata.is_empty() {
if metadata.is_empty() {
let rank = tensor.strides.len() as u32;
metadata.push(rank);
None
} else if tensor.strides.len() > metadata[0] as usize {
let old_rank = metadata[0];
let rank = tensor.strides.len() as u32;
Self::adjust_rank(metadata, bindings.len(), rank);
Some(old_rank)
} else {
None
};
Self::adjust_rank(metadata, bindings.len() - 1, rank);
}

Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
Self::register_shape(tensor.shape, old_rank, metadata);
Self::register_strides(tensor.strides, tensor.shape, None, metadata);
Self::register_shape(tensor.shape, None, metadata);

if R::require_array_lengths() {
let len = calculate_num_elems_dyn_rank(tensor.shape);
Expand All @@ -200,6 +213,7 @@ impl<R: Runtime> TensorState<R> {
let old_rank = metadata[0] as usize;
let rank_diff = rank as usize - old_rank;
let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
updated_metadata.push(rank);

for pos in 0..num_registered {
let stride_index = (pos * old_rank * 2) + 1;
Expand Down Expand Up @@ -228,19 +242,14 @@ impl<R: Runtime> TensorState<R> {
) {
let old_rank = if let Some(old_rank) = old_rank {
let rank = output[0];
let rank_diff = old_rank - rank;
let padded_strides = if rank_diff > 0 {
shape
.iter()
.take(old_rank as usize)
.map(|a| a.to_u32().unwrap())
.sum::<u32>()
} else {
0
};
let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize;

if rank_diff > 0 {
let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::<u32>();

for _ in 0..rank_diff {
output.push(padded_strides.to_u32().unwrap());
for _ in 0..rank_diff {
output.push(padded_strides);
}
}

old_rank as usize
Expand All @@ -256,10 +265,12 @@ impl<R: Runtime> TensorState<R> {
fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
let old_rank = if let Some(old_rank) = old_rank {
let rank = output[0];
let rank_diff = rank - old_rank;
let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize;

for _ in 0..rank_diff {
output.push(1);
if rank_diff > 0 {
for _ in 0..rank_diff {
output.push(1);
}
}

old_rank as usize
Expand Down
31 changes: 17 additions & 14 deletions crates/cubecl-core/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
pub struct ArrayHandleRef<'a, R: Runtime> {
pub handle: &'a cubecl_runtime::server::Handle<R::Server>,
pub length: [usize; 1],
pub(crate) length: [usize; 1],
}

pub enum ArrayArg<'a, R: Runtime> {
Expand Down Expand Up @@ -205,35 +205,38 @@ impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
impl<'a, R: Runtime> ArrayArg<'a, R> {
/// Create a new array argument.
///
/// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization
/// factor of 1.
pub fn new(handle: &'a cubecl_runtime::server::Handle<R::Server>, length: usize) -> Self {
ArrayArg::Handle {
handle: ArrayHandleRef::new(handle, length),
vectorization_factor: 1,
}
}
/// Create a new array argument specified with its vectorization factor.
pub fn vectorized(
vectorization_factor: u8,
/// # Safety
///
/// Specifying the wrong length may lead to out-of-bounds reads and writes.
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle<R::Server>,
length: usize,
vectorization_factor: u8,
) -> Self {
ArrayArg::Handle {
handle: ArrayHandleRef::new(handle, length),
handle: ArrayHandleRef::from_raw_parts(handle, length),
vectorization_factor,
}
}
}

impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
pub fn new(handle: &'a cubecl_runtime::server::Handle<R::Server>, length: usize) -> Self {
/// Create a new array handle reference.
///
/// # Safety
///
/// Specifying the wrong length may lead to out-of-bounds reads and writes.
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle<R::Server>,
length: usize,
) -> Self {
Self {
handle,
length: [length],
}
}

/// Return the handle as a tensor instead of an array.
pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
let shape = &self.length;

Expand Down
56 changes: 37 additions & 19 deletions crates/cubecl-core/src/frontend/element/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,36 @@ impl<C: CubePrimitive> LaunchArg for Tensor<C> {

/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle),
/// the strides and the shape.
#[derive(new)]
pub struct TensorHandleRef<'a, R: Runtime> {
/// Borrowed [server handle](cubecl_runtime::server::Handle) referenced by this tensor.
pub handle: &'a cubecl_runtime::server::Handle<R::Server>,
/// Borrowed strides of the tensor.
pub strides: &'a [usize],
/// Borrowed shape of the tensor.
pub shape: &'a [usize],
}

impl<'a, R: Runtime> TensorHandleRef<'a, R> {
/// Convert the handle into a [tensor argument](TensorArg).
pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
unsafe { TensorArg::from_raw_parts(self.handle, self.strides, self.shape, vectorisation) }
}
/// Create a handle from raw parts.
///
/// # Safety
///
/// If you provide wrong strides or shapes, it might create undefined behavior caused by
/// out-of-bounds reads and writes.
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle<R::Server>,
strides: &'a [usize],
shape: &'a [usize],
) -> Self {
Self {
handle,
strides,
shape,
}
}
}

/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
pub enum TensorArg<'a, R: Runtime> {
/// The tensor is passed with a tensor handle.
Expand All @@ -76,32 +99,27 @@ pub enum TensorArg<'a, R: Runtime> {
}

impl<'a, R: Runtime> TensorArg<'a, R> {
/// Create a new tensor argument.
/// Create a new tensor argument specified with its vectorization factor.
///
/// # Safety
///
/// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization
/// factor of 1.
pub fn new(
/// If you provide wrong strides or shapes, it might create undefined behavior caused by
/// out-of-bounds reads and writes.
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle<R::Server>,
strides: &'a [usize],
shape: &'a [usize],
) -> Self {
Self::Handle {
handle: TensorHandleRef::new(handle, strides, shape),
vectorization_factor: 1,
}
}
/// Create a new tensor argument specified with its vectorization factor.
pub fn vectorized(
factor: u8,
handle: &'a cubecl_runtime::server::Handle<R::Server>,
strides: &'a [usize],
shape: &'a [usize],
) -> Self {
Self::Handle {
handle: TensorHandleRef::new(handle, strides, shape),
vectorization_factor: factor,
unsafe {
Self::Handle {
handle: TensorHandleRef::from_raw_parts(handle, strides, shape),
vectorization_factor: factor,
}
}
}

/// Create an alias argument.
pub fn alias(position: usize) -> Self {
Self::Alias {
input_pos: position,
Expand Down
8 changes: 8 additions & 0 deletions crates/cubecl-core/src/id.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use cubecl_runtime::ExecutionMode;
use std::any::{Any, TypeId};
use std::fmt::Display;
use std::hash::{DefaultHasher, Hash, Hasher};
Expand All @@ -8,6 +9,7 @@ use std::sync::Arc;
pub struct KernelId {
/// `TypeId` of the type parameter the id was created with.
type_id: core::any::TypeId,
/// Optional extra information attached to the id; `None` until set.
info: Option<Info>,
/// Optional [execution mode](ExecutionMode); `None` until set.
mode: Option<ExecutionMode>,
}

impl Display for KernelId {
Expand All @@ -25,6 +27,7 @@ impl KernelId {
Self {
type_id: core::any::TypeId::of::<T>(),
info: None,
mode: None,
}
}

Expand All @@ -39,6 +42,11 @@ impl KernelId {
self.info = Some(Info::new(info));
self
}

/// Store the given [execution mode](ExecutionMode) on this id, overwriting any previous one.
pub fn mode(&mut self, mode: ExecutionMode) {
    // The previously stored mode (if any) is simply discarded.
    let _previous = self.mode.replace(mode);
}
}

/// Extra information
Expand Down
Loading

0 comments on commit 30d090c

Please sign in to comment.