Unchecked Execution Mode #51

Merged · 9 commits · Aug 8, 2024
README.md (24 changes: 13 additions & 11 deletions)

@@ -27,7 +27,7 @@ Simply annotate functions with the `cube` attribute to indicate that they should
 ```rust
 use cubecl::prelude::*;
 
-#[cube(launch)]
+#[cube(launch_unchecked)]
 fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
     if ABSOLUTE_POS < input.len() {
         output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
@@ -38,24 +38,26 @@ fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
 fn gelu_scalar<F: Float>(x: F) -> F {
     x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0
 }
 
 ```
 
-You can then launch the kernel using the autogenerated `gelu_array::launch` function.
+You can then launch the kernel using the autogenerated `gelu_array::launch_unchecked` function.
 
 ```rust
 fn launch<R: Runtime>(device: &R::Device) {
     let client = R::client(device);
     let input = &[-1., 0., 1., 5.];
     let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
 
-    gelu_array::launch::<F32, R>(
-        client.clone(),
-        CubeCount::Static(1, 1, 1),
-        CubeDim::new(input.len() as u32, 1, 1),
-        ArrayArg::new(&client.create(f32::as_bytes(input)), input.len()),
-        ArrayArg::new(&output_handle, input.len()),
-    );
+    let input_handle = client.create(f32::as_bytes(input));
+
+    unsafe {
+        gelu_array::launch_unchecked::<F32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(input.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts(&input_handle, input.len(), 1),
+            ArrayArg::from_raw_parts(&output_handle, input.len(), 1),
+        )
+    };
 
     let bytes = client.read(output_handle.binding());
     let output = f32::from_bytes(&bytes);
crates/cubecl-core/src/codegen/compiler.rs (3 changes: 2 additions & 1 deletion)

@@ -1,4 +1,5 @@
 use crate::ir::{Elem, KernelDefinition};
+use cubecl_runtime::ExecutionMode;
 use std::fmt::Display;
 
 /// Trait for compiled code representation
@@ -13,7 +14,7 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
     type Representation: CompilerRepresentation;
 
     /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
-    fn compile(kernel: KernelDefinition) -> Self::Representation;
+    fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
     /// The size of the given element in bytes.
     fn elem_size(elem: Elem) -> usize;
     /// The maximal size of a shared memory
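Every backend now receives the execution mode when compiling, so checked and unchecked variants of the same kernel can generate different code. Below is a minimal, self-contained sketch of the idea: the `lower_index` helper is hypothetical (not a CubeCL API), and the `Checked`/`Unchecked` variants of `ExecutionMode` are assumed from the PR's naming.

```rust
use cubecl_runtime::ExecutionMode;

// Hypothetical index-lowering step: checked mode clamps the generated
// access so it can never go out of bounds; unchecked mode emits the raw
// (faster, but unsafe on bad input) access.
fn lower_index(array: &str, index: &str, len: &str, mode: ExecutionMode) -> String {
    match mode {
        ExecutionMode::Checked => format!("{array}[min({index}, {len} - 1)]"),
        ExecutionMode::Unchecked => format!("{array}[{index}]"),
    }
}

fn main() {
    println!("{}", lower_index("input", "ABSOLUTE_POS", "len", ExecutionMode::Checked));
    println!("{}", lower_index("input", "ABSOLUTE_POS", "len", ExecutionMode::Unchecked));
}
```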
crates/cubecl-core/src/compute/kernel.rs (19 changes: 11 additions & 8 deletions)

@@ -5,7 +5,10 @@ use std::{
 
 use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId};
 use alloc::sync::Arc;
-use cubecl_runtime::server::{Binding, ComputeServer};
+use cubecl_runtime::{
+    server::{Binding, ComputeServer},
+    ExecutionMode,
+};
 
 /// A kernel, compiled in the target language
 pub struct CompiledKernel {
@@ -157,7 +160,7 @@ pub trait CubeTask: Send + Sync {
     /// Identifier for the kernel, used for caching kernel compilation.
     fn id(&self) -> KernelId;
     /// Compile the kernel into source
-    fn compile(&self) -> CompiledKernel;
+    fn compile(&self, mode: ExecutionMode) -> CompiledKernel;
 }
 
 /// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
@@ -168,10 +171,10 @@ pub struct KernelTask<C: Compiler, K: Kernel> {
 }
 
 impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
-    fn compile(&self) -> CompiledKernel {
+    fn compile(&self, mode: ExecutionMode) -> CompiledKernel {
         let gpu_ir = self.kernel_definition.define();
         let cube_dim = gpu_ir.cube_dim;
-        let lower_level_ir = C::compile(gpu_ir);
+        let lower_level_ir = C::compile(gpu_ir, mode);
         let shared_mem_bytes = lower_level_ir.shared_memory_size();
         let source = lower_level_ir.to_string();
 
@@ -190,8 +193,8 @@ impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
 }
 
 impl CubeTask for Arc<dyn CubeTask> {
-    fn compile(&self) -> CompiledKernel {
-        self.as_ref().compile()
+    fn compile(&self, mode: ExecutionMode) -> CompiledKernel {
+        self.as_ref().compile(mode)
     }
 
     fn id(&self) -> KernelId {
@@ -200,8 +203,8 @@ impl CubeTask for Arc<dyn CubeTask> {
 }
 
 impl CubeTask for Box<dyn CubeTask> {
-    fn compile(&self) -> CompiledKernel {
-        self.as_ref().compile()
+    fn compile(&self, mode: ExecutionMode) -> CompiledKernel {
+        self.as_ref().compile(mode)
    }
 
     fn id(&self) -> KernelId {
crates/cubecl-core/src/compute/launcher.rs (61 changes: 36 additions & 25 deletions)

@@ -89,6 +89,24 @@ impl<R: Runtime> KernelLauncher<R> {
         client.execute(kernel, cube_count, bindings);
     }
 
+    /// Launch the kernel without bounds checks.
+    ///
+    /// # Safety
+    ///
+    /// Out-of-bounds reads and writes can happen.
+    pub unsafe fn launch_unchecked<K: Kernel>(
+        self,
+        cube_count: CubeCount<R::Server>,
+        kernel: K,
+        client: &ComputeClient<R::Server, R::Channel>,
+    ) {
+        let bindings = self.into_bindings(client);
+
+        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
+
+        client.execute_unchecked(kernel, cube_count, bindings);
+    }
+
     /// We need to create the bindings in the same order they are defined in the compilation step.
     ///
     /// The function [crate::KernelIntegrator::integrate] starts by registering the input tensors followed
@@ -174,21 +192,16 @@ impl<R: Runtime> TensorState<R> {
 
         bindings.push(tensor.handle.clone().binding());
 
-        let old_rank = if metadata.is_empty() {
+        if metadata.is_empty() {
             let rank = tensor.strides.len() as u32;
             metadata.push(rank);
-            None
         } else if tensor.strides.len() > metadata[0] as usize {
-            let old_rank = metadata[0];
             let rank = tensor.strides.len() as u32;
-            Self::adjust_rank(metadata, bindings.len(), rank);
-            Some(old_rank)
-        } else {
-            None
-        };
+            Self::adjust_rank(metadata, bindings.len() - 1, rank);
+        }
 
-        Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
-        Self::register_shape(tensor.shape, old_rank, metadata);
+        Self::register_strides(tensor.strides, tensor.shape, None, metadata);
+        Self::register_shape(tensor.shape, None, metadata);
 
         if R::require_array_lengths() {
             let len = calculate_num_elems_dyn_rank(tensor.shape);
@@ -200,6 +213,7 @@ impl<R: Runtime> TensorState<R> {
         let old_rank = metadata[0] as usize;
         let rank_diff = rank as usize - old_rank;
         let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
+        updated_metadata.push(rank);
 
         for pos in 0..num_registered {
             let stride_index = (pos * old_rank * 2) + 1;
@@ -228,19 +242,14 @@ impl<R: Runtime> TensorState<R> {
     ) {
         let old_rank = if let Some(old_rank) = old_rank {
             let rank = output[0];
-            let rank_diff = old_rank - rank;
-            let padded_strides = if rank_diff > 0 {
-                shape
-                    .iter()
-                    .take(old_rank as usize)
-                    .map(|a| a.to_u32().unwrap())
-                    .sum::<u32>()
-            } else {
-                0
-            };
-
-            for _ in 0..rank_diff {
-                output.push(padded_strides.to_u32().unwrap());
-            }
+            let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize;
+
+            if rank_diff > 0 {
+                let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::<u32>();
+
+                for _ in 0..rank_diff {
+                    output.push(padded_strides);
+                }
+            }
 
             old_rank as usize
@@ -256,10 +265,12 @@ impl<R: Runtime> TensorState<R> {
     fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
         let old_rank = if let Some(old_rank) = old_rank {
             let rank = output[0];
-            let rank_diff = rank - old_rank;
+            let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize;
 
-            for _ in 0..rank_diff {
-                output.push(1);
+            if rank_diff > 0 {
+                for _ in 0..rank_diff {
+                    output.push(1);
+                }
             }
 
             old_rank as usize
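For intuition about what `adjust_rank`, `register_strides`, and `register_shape` maintain: judging from the indexing arithmetic above (`stride_index = (pos * old_rank * 2) + 1`), the metadata vector appears to pack one shared rank entry followed by `rank` stride values and `rank` shape values per registered tensor. A self-contained sketch of that assumed layout, with made-up values:

```rust
// Assumed layout (inferred from the hunks above, not a documented API):
// metadata = [rank, strides(t0)..., shape(t0)..., strides(t1)..., shape(t1)..., ...]
fn main() {
    let rank = 2usize;
    // One registered rank-2 tensor with strides [4, 1] and shape [2, 4]:
    let metadata: Vec<u32> = vec![rank as u32, 4, 1, 2, 4];

    // Strides of tensor `pos` start right after the shared rank entry,
    // matching the `(pos * old_rank * 2) + 1` formula above.
    let pos = 0;
    let stride_index = (pos * rank * 2) + 1;
    assert_eq!(metadata[stride_index], 4);

    // The shape follows immediately after the strides.
    let shape_index = stride_index + rank;
    assert_eq!(metadata[shape_index], 2);
}
```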
crates/cubecl-core/src/frontend/element/array.rs (31 changes: 17 additions & 14 deletions)

@@ -144,7 +144,7 @@ impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
 /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle).
 pub struct ArrayHandleRef<'a, R: Runtime> {
     pub handle: &'a cubecl_runtime::server::Handle<R::Server>,
-    pub length: [usize; 1],
+    pub(crate) length: [usize; 1],
 }
 
 pub enum ArrayArg<'a, R: Runtime> {
@@ -205,35 +205,38 @@ impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
 impl<'a, R: Runtime> ArrayArg<'a, R> {
-    /// Create a new array argument.
+    /// Create a new array argument specified with its vectorization factor.
     ///
-    /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization
-    /// factor of 1.
-    pub fn new(handle: &'a cubecl_runtime::server::Handle<R::Server>, length: usize) -> Self {
-        ArrayArg::Handle {
-            handle: ArrayHandleRef::new(handle, length),
-            vectorization_factor: 1,
-        }
-    }
-    /// Create a new array argument specified with its vectorization factor.
-    pub fn vectorized(
-        vectorization_factor: u8,
+    /// # Safety
+    ///
+    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
+    pub unsafe fn from_raw_parts(
         handle: &'a cubecl_runtime::server::Handle<R::Server>,
         length: usize,
+        vectorization_factor: u8,
     ) -> Self {
         ArrayArg::Handle {
-            handle: ArrayHandleRef::new(handle, length),
+            handle: ArrayHandleRef::from_raw_parts(handle, length),
             vectorization_factor,
         }
     }
 }
 
 impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
-    pub fn new(handle: &'a cubecl_runtime::server::Handle<R::Server>, length: usize) -> Self {
+    /// Create a new array handle reference.
+    ///
+    /// # Safety
+    ///
+    /// Specifying the wrong length may lead to out-of-bounds reads and writes.
+    pub unsafe fn from_raw_parts(
+        handle: &'a cubecl_runtime::server::Handle<R::Server>,
+        length: usize,
+    ) -> Self {
         Self {
             handle,
             length: [length],
         }
     }
 
     /// Return the handle as a tensor instead of an array.
     pub fn as_tensor(&self) -> TensorHandleRef<'_, R> {
         let shape = &self.length;
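A brief usage sketch of the new constructor, mirroring the README's launch example (the `demo` function and its values are illustrative, not part of the diff):

```rust
use cubecl::prelude::*;

fn demo<R: Runtime>(device: &R::Device) {
    let client = R::client(device);
    let data = [1.0f32, 2.0, 3.0, 4.0];
    let handle = client.create(f32::as_bytes(&data));

    // Safety: `data.len()` is the exact element count behind `handle`;
    // this is the invariant the unsafe constructor trusts the caller for.
    let _arg = unsafe { ArrayArg::<R>::from_raw_parts(&handle, data.len(), 1) };
}
```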
crates/cubecl-core/src/frontend/element/tensor.rs (56 changes: 37 additions & 19 deletions)

@@ -52,13 +52,36 @@ impl<C: CubePrimitive> LaunchArg for Tensor<C> {
 
 /// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle),
 /// the strides and the shape.
-#[derive(new)]
 pub struct TensorHandleRef<'a, R: Runtime> {
     pub handle: &'a cubecl_runtime::server::Handle<R::Server>,
     pub strides: &'a [usize],
     pub shape: &'a [usize],
 }
 
+impl<'a, R: Runtime> TensorHandleRef<'a, R> {
+    /// Convert the handle into a [tensor argument](TensorArg).
+    pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
+        unsafe { TensorArg::from_raw_parts(self.handle, self.strides, self.shape, vectorisation) }
+    }
+    /// Create a handle from raw parts.
+    ///
+    /// # Safety
+    ///
+    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
+    /// out-of-bounds reads and writes.
+    pub unsafe fn from_raw_parts(
+        handle: &'a cubecl_runtime::server::Handle<R::Server>,
+        strides: &'a [usize],
+        shape: &'a [usize],
+    ) -> Self {
+        Self {
+            handle,
+            strides,
+            shape,
+        }
+    }
+}
+
 /// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
 pub enum TensorArg<'a, R: Runtime> {
     /// The tensor is passed with a tensor handle.
@@ -76,32 +99,27 @@ pub enum TensorArg<'a, R: Runtime> {
 }
 
 impl<'a, R: Runtime> TensorArg<'a, R> {
-    /// Create a new tensor argument.
+    /// Create a new tensor argument specified with its vectorization factor.
     ///
-    /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization
-    /// factor of 1.
-    pub fn new(
+    /// # Safety
+    ///
+    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
+    /// out-of-bounds reads and writes.
+    pub unsafe fn from_raw_parts(
         handle: &'a cubecl_runtime::server::Handle<R::Server>,
         strides: &'a [usize],
         shape: &'a [usize],
+        factor: u8,
     ) -> Self {
-        Self::Handle {
-            handle: TensorHandleRef::new(handle, strides, shape),
-            vectorization_factor: 1,
-        }
-    }
-    /// Create a new tensor argument specified with its vectorization factor.
-    pub fn vectorized(
-        factor: u8,
-        handle: &'a cubecl_runtime::server::Handle<R::Server>,
-        strides: &'a [usize],
-        shape: &'a [usize],
-    ) -> Self {
-        Self::Handle {
-            handle: TensorHandleRef::new(handle, strides, shape),
-            vectorization_factor: factor,
+        unsafe {
+            Self::Handle {
+                handle: TensorHandleRef::from_raw_parts(handle, strides, shape),
+                vectorization_factor: factor,
+            }
         }
     }
 
     /// Create an alias argument.
     pub fn alias(position: usize) -> Self {
         Self::Alias {
             input_pos: position,
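An illustrative caller showing both routes to a tensor argument, again modeled on the README example (the strides and shape describe a contiguous 2 x 4 buffer; not part of the diff):

```rust
use cubecl::prelude::*;

fn demo<R: Runtime>(device: &R::Device) {
    let client = R::client(device);
    let data = [0.0f32; 8];
    let handle = client.create(f32::as_bytes(&data));
    // Row-major strides for a 2 x 4 tensor over the 8-element buffer.
    let (strides, shape) = ([4usize, 1], [2usize, 4]);

    // Directly, vouching for the strides and shape ourselves:
    let _arg = unsafe { TensorArg::<R>::from_raw_parts(&handle, &strides, &shape, 1) };

    // Or through a handle reference, which wraps the same unsafe call:
    let tensor = unsafe { TensorHandleRef::<R>::from_raw_parts(&handle, &strides, &shape) };
    let _arg = tensor.as_tensor_arg(1);
}
```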
crates/cubecl-core/src/id.rs (8 changes: 8 additions & 0 deletions)

@@ -1,3 +1,4 @@
+use cubecl_runtime::ExecutionMode;
 use std::any::{Any, TypeId};
 use std::fmt::Display;
 use std::hash::{DefaultHasher, Hash, Hasher};
@@ -8,6 +9,7 @@ use std::sync::Arc;
 pub struct KernelId {
     type_id: core::any::TypeId,
     info: Option<Info>,
+    mode: Option<ExecutionMode>,
 }
 
 impl Display for KernelId {
@@ -25,6 +27,7 @@ impl KernelId {
         Self {
             type_id: core::any::TypeId::of::<T>(),
             info: None,
+            mode: None,
         }
     }
 
@@ -39,6 +42,11 @@ impl KernelId {
         self.info = Some(Info::new(info));
         self
     }
+
+    /// Set the [execution mode](ExecutionMode).
+    pub fn mode(&mut self, mode: ExecutionMode) {
+        self.mode = Some(mode);
+    }
 }
 
 /// Extra information
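Because the mode is now part of `KernelId`, a runtime can cache checked and unchecked binaries of the same kernel side by side. A sketch of how the pieces might fit together (the import paths and the helper itself are assumptions, not shown in this PR):

```rust
use cubecl_core::compute::CubeTask; // path assumed
use cubecl_core::KernelId; // path assumed
use cubecl_runtime::ExecutionMode;

// Hypothetical cache-key helper: the same kernel compiled in different
// modes yields different ids, so a checked binary is never reused for an
// unchecked launch (or vice versa).
fn cache_key(task: &dyn CubeTask, mode: ExecutionMode) -> KernelId {
    let mut id = task.id();
    id.mode(mode);
    id
}
```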