Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/autotune result #399

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cubecl-linalg/src/matmul/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn launch_ref<R: Runtime, EG: Float>(
Ok(())
}
Strategy::Simple => {
simple::launch_ref::<R, EG>(client, lhs, rhs, out);
simple::launch_ref::<R, EG>(client, lhs, rhs, out)?;
Ok(())
}
Strategy::Auto => {
Expand Down
6 changes: 5 additions & 1 deletion crates/cubecl-linalg/src/matmul/kernels/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use cubecl_core::ir::Elem;
use cubecl_core::{ir::Elem, CubeCount};
use std::fmt::Debug;

use crate::matmul::components::InvalidConfigError;
Expand All @@ -11,6 +11,7 @@ pub enum MatmulLaunchError {

pub enum MatmulAvailabilityError {
PlaneDimUnknown,
CubeCountTooBig(CubeCount),
PlaneDimUnsupported {
plane_dim: u32,
},
Expand Down Expand Up @@ -123,6 +124,9 @@ impl Debug for MatmulAvailabilityError {
MatmulAvailabilityError::PlaneOperationsUnavailable => {
writeln!(f, "Plane operations not supported.")
}
MatmulAvailabilityError::CubeCountTooBig(count) => {
writeln!(f, "Cube count too big {count:?}")
}
MatmulAvailabilityError::PlaneDimUnknown => {
writeln!(f, "Plane dimension unknown.")
},
Expand Down
29 changes: 21 additions & 8 deletions crates/cubecl-linalg/src/matmul/kernels/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use cubecl_core as cubecl;

use crate::tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle};

use super::MatmulLaunchError;

#[cube(launch_unchecked)]
fn matmul_kernel<F: Float>(
lhs: &Tensor<Line<F>>,
Expand Down Expand Up @@ -83,21 +85,21 @@ pub fn launch_ref<R: Runtime, E: Float>(
lhs: &TensorHandleRef<'_, R>,
rhs: &TensorHandleRef<'_, R>,
out: &TensorHandleRef<'_, R>,
) {
) -> Result<(), MatmulLaunchError> {
let lhs =
TensorHandle::<R, E>::new(lhs.shape.to_vec(), lhs.strides.to_vec(), lhs.handle.clone());
let rhs =
TensorHandle::<R, E>::new(rhs.shape.to_vec(), rhs.strides.to_vec(), rhs.handle.clone());

launch(client, lhs, rhs, out);
launch(client, lhs, rhs, out)
}

pub fn launch<R: Runtime, E: Float>(
client: &ComputeClient<R::Server, R::Channel>,
lhs: TensorHandle<R, E>,
rhs: TensorHandle<R, E>,
out: &TensorHandleRef<'_, R>,
) {
) -> Result<(), MatmulLaunchError> {
let (cube_dim_x, cube_dim_y) = (32, 8);
let ndims = lhs.shape.len();
let dim1 = ndims - 1;
Expand Down Expand Up @@ -147,7 +149,7 @@ pub fn launch<R: Runtime, E: Float>(
out.shape,
cube_dim_x,
cube_dim_y,
);
)?;

let vectorization_factor = match lhs.shape[ndims - 1] % 4 == 0 {
true => 4,
Expand All @@ -170,6 +172,8 @@ pub fn launch<R: Runtime, E: Float>(
Some(ndims as u32 - 2),
);
};

Ok(())
}

fn simple_cube_count(
Expand All @@ -178,19 +182,28 @@ fn simple_cube_count(
output_shape: &[usize],
cube_dim_x: usize,
cube_dim_y: usize,
) -> CubeCount {
) -> Result<CubeCount, MatmulLaunchError> {
let ndims = lhs_shape.len();
let num_rows = lhs_shape[ndims - 2];
let num_cols = rhs_shape[ndims - 1];

let cubes_x = f32::ceil(num_rows as f32 / cube_dim_x as f32) as u32;
let cubes_y = f32::ceil(num_cols as f32 / cube_dim_y as f32) as u32;
let mut num_iter = 1;
let mut num_iter = 1u32;

#[allow(clippy::needless_range_loop)]
for i in 0..ndims - 2 {
num_iter *= output_shape[i];
num_iter *= output_shape[i] as u32;
}

let result = CubeCount::Static(cubes_x, cubes_y, num_iter);
let max_cube_count = u16::MAX as u32;

if cubes_x > max_cube_count || cubes_y > max_cube_count || num_iter > max_cube_count {
return Err(MatmulLaunchError::Unavailable(
super::MatmulAvailabilityError::CubeCountTooBig(result),
));
}

CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
Ok(result)
}
2 changes: 1 addition & 1 deletion crates/cubecl-linalg/src/matmul/tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn test_simple<R: Runtime, F: Float + CubeElement + Display>(
let expected = case.matmul_cpu::<R, F>(&lhs, &rhs, &client);

let out: TensorHandle<R, F> = case.empty_out(&client);
simple::launch::<R, F>(&client, lhs, rhs, &out.as_ref());
simple::launch::<R, F>(&client, lhs, rhs, &out.as_ref()).unwrap();

if let Err(e) = assert_equals_approx::<R, F>(&client, out.handle, &expected, 10e-4) {
panic!("{}", e);
Expand Down
8 changes: 6 additions & 2 deletions crates/cubecl-macros/src/generate/autotune.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ impl AutotuneOperations {

fn generate_op_impl(&self, name: &Ident, func_name: &Path) -> TokenStream {
let operation = tune_type("AutotuneOperation");
let error = tune_type("AutotuneError");

let key = &self.key;
let (generics, generic_names, where_clause) = self.generics.split_for_impl();
Expand Down Expand Up @@ -386,8 +387,11 @@ impl AutotuneOperations {

quote! {
impl #generics #operation<#output> for #name #generic_names #where_clause {
fn execute(self: Box<Self>) -> #output {
#func_name #turbofish(#(#func_args),*)
fn execute(self: Box<Self>) -> Result<#output, #error> {
#func_name #turbofish(#(#func_args),*).map_err(|err| {
let err: #error = err.into();
err
})
}

fn clone(&self) -> Box<dyn #operation<#output>> {
Expand Down
12 changes: 9 additions & 3 deletions crates/cubecl-runtime/src/tune/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<AK: AutotuneKey + 'static, ID: Hash + PartialEq + Eq + Clone + Display> Loc
if let Some(tuner) = map.get(id) {
if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
let op = autotune_operation_set.fastest(fastest_index);
return op.execute();
return op.execute().expect("Should run when selected by autotune.");
}
}
}
Expand Down Expand Up @@ -91,7 +91,10 @@ impl<AK: AutotuneKey + 'static, ID: Hash + PartialEq + Eq + Clone + Display> Loc

match fastest {
TuneCacheResult::Hit { fastest_index } => {
return autotune_operation_set.fastest(fastest_index).execute();
return autotune_operation_set
.fastest(fastest_index)
.execute()
.expect("Should run when selected by autotune.");
}
TuneCacheResult::Miss => {
// We don't know the results yet, start autotuning.
Expand Down Expand Up @@ -147,6 +150,9 @@ impl<AK: AutotuneKey + 'static, ID: Hash + PartialEq + Eq + Clone + Display> Loc
}
};

autotune_operation_set.fastest(fastest).execute()
autotune_operation_set
.fastest(fastest)
.execute()
.expect("Should run when selected by autotune.")
}
}
4 changes: 3 additions & 1 deletion crates/cubecl-runtime/src/tune/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::hash::Hash;

use super::AutotuneError;

/// Default checksum for an operation set
#[cfg(autotune_persistent_cache)]
pub fn compute_checksum<Out: Send + 'static>(
Expand Down Expand Up @@ -46,7 +48,7 @@ pub trait AutotuneOperationSet<K: Send + 'static, Output: Send + 'static = ()>:
/// Contains operation to run and inputs on which to run it
pub trait AutotuneOperation<Output: Send + 'static = ()>: Send + core::fmt::Debug {
/// Runs the operation
fn execute(self: Box<Self>) -> Output;
fn execute(self: Box<Self>) -> Result<Output, AutotuneError>;

/// The name of the operation.
fn name(&self) -> &str {
Expand Down
15 changes: 10 additions & 5 deletions crates/cubecl-runtime/src/tune/tune_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::server::ComputeServer;
#[cfg(feature = "std")]
use cubecl_common::benchmark::{BenchmarkDurations, TimingMethod};

use super::AutotuneError;
use super::AutotuneOperation;
use alloc::boxed::Box;

Expand All @@ -25,12 +26,13 @@ impl<S: ComputeServer + 'static, C: ComputeChannel<S> + 'static, Out: Send + 'st
{
/// Benchmark how long this operation takes for a number of samples.
#[cfg(feature = "std")]
pub async fn sample_durations(self) -> BenchmarkDurations {
pub async fn sample_durations(self) -> Result<BenchmarkDurations, AutotuneError> {
let operation = self.operation.clone();

// If the inner operation need autotuning as well, we need to call it before.
let _ = self.client.sync().await;
operation.clone().execute();
operation.clone().execute()?;

let _ = self.client.sync().await;

let client = self.client.clone();
Expand All @@ -42,7 +44,10 @@ impl<S: ComputeServer + 'static, C: ComputeChannel<S> + 'static, Out: Send + 'st
let mut durations = Vec::with_capacity(num_samples);

for _ in 0..num_samples {
operation.clone().execute();
operation
.clone()
.execute()
.expect("Should not fail when previsously tried during the warmup.");
// For benchmarks - we need to wait for all tasks to complete before returning.
let duration = match client.sync_elapsed().await {
Ok(val) => val,
Expand All @@ -64,9 +69,9 @@ impl<S: ComputeServer + 'static, C: ComputeChannel<S> + 'static, Out: Send + 'st
})
.await;

BenchmarkDurations {
Ok(BenchmarkDurations {
timing_method: TimingMethod::DeviceOnly,
durations,
}
})
}
}
33 changes: 27 additions & 6 deletions crates/cubecl-runtime/src/tune/tuner.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use async_channel::{Receiver, Sender};
use cubecl_common::future;

use core::any::Any;
use core::future::Future;
use core::mem::ManuallyDrop;
use cubecl_common::stub::Duration;
Expand Down Expand Up @@ -40,9 +41,18 @@ enum AutotuneMessage<K> {
}

/// Error from running autotune.
#[derive(Debug)]
pub enum AutotuneError {
/// An unknown error happened.
Unknown(String),
/// An error catched with panic unwind.
PanicUnwind(ManuallyDrop<Box<dyn Any + Send>>),
}

impl From<String> for AutotuneError {
fn from(value: String) -> Self {
Self::Unknown(value)
}
}

#[allow(clippy::new_without_default)]
Expand Down Expand Up @@ -154,10 +164,15 @@ impl<K: AutotuneKey> Tuner<K> {
let sample_fut = future::catch_unwind(sample_fut);
let result = sample_fut.await;

let result = result.map_err(|e| {
log::warn!("Caught error while benchmarking, falling back to next operation.");
ManuallyDrop::new(e)
});
let result = match result {
Ok(result) => result,
Err(err) => {
log::warn!(
"Caught unknown error while benchmarking, falling back to next operation."
);
Err(AutotuneError::PanicUnwind(ManuallyDrop::new(err)))
}
};

let result = result.map(|durations| {
log::info!("Name: {name} => {}", durations);
Expand All @@ -167,11 +182,17 @@ impl<K: AutotuneKey> Tuner<K> {
bench_results.push(result);
}

// // Panic if all tuners panicked.
// Panic if all tuners panicked.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
if bench_results.iter().all(|result| result.is_err()) {
let first_error = bench_results.into_iter().next().unwrap().err().unwrap();
resume_unwind(ManuallyDrop::into_inner(first_error));

match first_error {
AutotuneError::Unknown(reason) => panic!("{reason}"),
AutotuneError::PanicUnwind(err) => {
resume_unwind(ManuallyDrop::into_inner(err));
}
}
}

// Finds the fastest operation (by the median time).
Expand Down
6 changes: 4 additions & 2 deletions crates/cubecl-runtime/tests/dummy/tune/autotune_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use cubecl_runtime::{
client::ComputeClient,
server::{Binding, CubeCount},
tune::AutotuneOperation,
tune::{AutotuneError, AutotuneOperation},
};
use derive_new::new;

Expand All @@ -21,12 +21,14 @@ pub struct OneKernelAutotuneOperation {

impl AutotuneOperation for OneKernelAutotuneOperation {
/// Executes the operation on given bindings and server, with the additional parameters
fn execute(self: Box<Self>) {
fn execute(self: Box<Self>) -> Result<(), AutotuneError> {
self.client.execute(
self.kernel.clone(),
CubeCount::Static(1, 1, 1),
self.bindings,
);

Ok(())
}

fn clone(&self) -> Box<dyn AutotuneOperation> {
Expand Down