Skip to content

Commit

Permalink
Add support for numeric match at runtime (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Sep 26, 2024
1 parent 8a1e78d commit d509fbd
Show file tree
Hide file tree
Showing 15 changed files with 454 additions and 12 deletions.
109 changes: 108 additions & 1 deletion crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use num_traits::NumCast;

use crate::frontend::{CubeContext, ExpandElement};
use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop};
use crate::{
frontend::{CubeContext, ExpandElement},
ir::Switch,
};

use super::{assign, CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric};

Expand Down Expand Up @@ -408,6 +411,110 @@ pub fn if_else_expr_expand<C: CubePrimitive>(
}
}

pub struct SwitchExpand<I: Int> {
value: ExpandElementTyped<I>,
default: CubeContext,
cases: Vec<(ExpandElementTyped<I>, CubeContext)>,
}

impl<I: Int> SwitchExpand<I> {
pub fn case(
mut self,
context: &mut CubeContext,
value: impl Int,
block: impl FnOnce(&mut CubeContext),
) -> Self {
let value = I::from(value).unwrap();
let mut case_child = context.child();
block(&mut case_child);
self.cases.push((value.into(), case_child));
self
}

pub fn finish(self, context: &mut CubeContext) {
let value_var = *self.value.expand;
context.register(Branch::Switch(Switch {
value: value_var,
scope_default: self.default.into_scope(),
cases: self
.cases
.into_iter()
.map(|it| (*it.0.expand, it.1.into_scope()))
.collect(),
}));
}
}

pub fn switch_expand<I: Int>(
context: &mut CubeContext,
value: ExpandElementTyped<I>,
default_block: impl FnOnce(&mut CubeContext),
) -> SwitchExpand<I> {
let mut default_child = context.child();
default_block(&mut default_child);

SwitchExpand {
value,
default: default_child,
cases: Vec::new(),
}
}

pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
value: ExpandElementTyped<I>,
out: ExpandElementTyped<C>,
default: CubeContext,
cases: Vec<(ExpandElementTyped<I>, CubeContext)>,
}

impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
pub fn case(
mut self,
context: &mut CubeContext,
value: impl Int,
block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
) -> Self {
let value = I::from(value).unwrap();
let mut case_child = context.child();
let ret = block(&mut case_child);
assign::expand(&mut case_child, ret, self.out.clone());
self.cases.push((value.into(), case_child));
self
}

pub fn finish(self, context: &mut CubeContext) -> ExpandElementTyped<C> {
let value_var = *self.value.expand;
context.register(Branch::Switch(Switch {
value: value_var,
scope_default: self.default.into_scope(),
cases: self
.cases
.into_iter()
.map(|it| (*it.0.expand, it.1.into_scope()))
.collect(),
}));
self.out
}
}

pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
context: &mut CubeContext,
value: ExpandElementTyped<I>,
default_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
) -> SwitchExpandExpr<I, C> {
let mut default_child = context.child();
let default = default_block(&mut default_child);
let out: ExpandElementTyped<C> = context.create_local_variable(default.expand.item()).into();
assign::expand(&mut default_child, default, out.clone());

SwitchExpandExpr {
value,
out,
default: default_child,
cases: Vec::new(),
}
}

pub fn break_expand(context: &mut CubeContext) {
context.register(Branch::Break);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ impl<T: CubePrimitive + Clone> Array<T> {
.into()
}

pub fn __expand_vectorized<S: Index>(
pub fn __expand_vectorized(
context: &mut CubeContext,
size: S,
size: ExpandElementTyped<u32>,
vectorization_factor: u32,
) -> <Self as CubeType>::ExpandType {
let size = size.value();
Expand Down
10 changes: 10 additions & 0 deletions crates/cubecl-core/src/ir/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub enum Branch {
If(If),
/// An if else statement.
IfElse(IfElse),
/// A switch statement
Switch(Switch),
/// A range loop.
RangeLoop(RangeLoop),
/// A loop.
Expand All @@ -33,6 +35,14 @@ pub struct IfElse {
pub scope_else: Scope,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct Switch {
pub value: Variable,
pub scope_default: Scope,
pub cases: Vec<(Variable, Scope)>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct RangeLoop {
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod launch;
pub mod sequence;
pub mod slice;
pub mod subcube;
pub mod switch;
pub mod topology;
pub mod unary;

Expand All @@ -21,6 +22,7 @@ macro_rules! testgen_all {
cubecl_core::testgen_cmma!();
cubecl_core::testgen_slice!();
cubecl_core::testgen_assign!();
cubecl_core::testgen_switch!();
cubecl_core::testgen_topology!();
cubecl_core::testgen_sequence!();
cubecl_core::testgen_unary!();
Expand Down
152 changes: 152 additions & 0 deletions crates/cubecl-core/src/runtime_tests/switch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use crate as cubecl;

use cubecl::prelude::*;

#[cube(launch)]
pub fn kernel_switch_simple(output: &mut Array<f32>, case: u32) {
if UNIT_POS == 0 {
match case {
0 => {
output[0] = 1.0;
}
1 => {
output[0] = 3.0;
}
_ => {
output[0] = 5.0;
}
}
}
}

#[cube(launch)]
pub fn kernel_switch_value_expr(output: &mut Array<f32>, case: u32) {
if UNIT_POS == 0 {
let value = match case {
0 => 1.0f32,
1 => 3.0f32,
_ => 5.0f32,
};
output[0] = value;
}
}

#[cube(launch)]
pub fn kernel_switch_or_arm(output: &mut Array<f32>, case: u32) {
if UNIT_POS == 0 {
let value = match case {
0 => 1.0f32,
1 | 2 => 3.0f32,
_ => 5.0f32,
};
output[0] = value;
}
}

pub fn test_switch_statement<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

let vectorization = 2;

kernel_switch_simple::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
unsafe { ArrayArg::from_raw_parts(&handle, 2, vectorization) },
ScalarArg::new(0),
);

let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 1.0);
}

pub fn test_switch_used_as_value<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

let vectorization = 2;

kernel_switch_value_expr::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
unsafe { ArrayArg::from_raw_parts(&handle, 2, vectorization) },
ScalarArg::new(1),
);

let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 3.0);
}

pub fn test_switch_default<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

let vectorization = 2;

kernel_switch_value_expr::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
unsafe { ArrayArg::from_raw_parts(&handle, 2, vectorization) },
ScalarArg::new(5),
);

let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 5.0);
}

pub fn test_switch_or_branch<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));

let vectorization = 2;

kernel_switch_or_arm::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::default(),
unsafe { ArrayArg::from_raw_parts(&handle, 2, vectorization) },
ScalarArg::new(2),
);

let actual = client.read(handle.binding());
let actual = f32::from_bytes(&actual);

assert_eq!(actual[0], 3.0);
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_switch {
() => {
use super::*;

#[test]
fn test_switch_statement() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::switch::test_switch_statement::<TestRuntime>(client);
}

#[test]
fn test_switch_used_as_value() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::switch::test_switch_used_as_value::<TestRuntime>(client);
}

#[test]
fn test_switch_default() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::switch::test_switch_default::<TestRuntime>(client);
}

#[test]
fn test_switch_or_branch() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::switch::test_switch_or_branch::<TestRuntime>(client);
}
};
}
11 changes: 11 additions & 0 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,17 @@ impl CudaCompiler {
instructions_if: self.compile_scope(&mut op.scope_if),
instructions_else: self.compile_scope(&mut op.scope_else),
}),
gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
value: self.compile_variable(op.value),
instructions_default: self.compile_scope(&mut op.scope_default),
instructions_cases: op
.cases
.into_iter()
.map(|(val, mut block)| {
(self.compile_variable(val), self.compile_scope(&mut block))
})
.collect(),
}),
gpu::Branch::Return => instructions.push(Instruction::Return),
gpu::Branch::Break => instructions.push(Instruction::Break),
gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
Expand Down
24 changes: 24 additions & 0 deletions crates/cubecl-cuda/src/compiler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ pub enum Instruction {
instructions_if: Vec<Self>,
instructions_else: Vec<Self>,
},
Switch {
value: Variable,
instructions_default: Vec<Self>,
instructions_cases: Vec<(Variable, Vec<Self>)>,
},
Slice {
input: Variable,
start: Variable,
Expand Down Expand Up @@ -242,6 +247,25 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
}
f.write_str("}\n")
}
Instruction::Switch {
value,
instructions_default,
instructions_cases,
} => {
f.write_fmt(format_args!("switch({value}) {{\n"))?;
for (value, block) in instructions_cases {
f.write_fmt(format_args!("case {value}:\n{{\n"))?;
for i in block {
i.fmt(f)?;
}
f.write_str("break;\n}\n")?;
}
f.write_str("default:\n{")?;
for i in instructions_default {
i.fmt(f)?;
}
f.write_str("}\n}\n")
}
Instruction::Stride { dim, position, out } => f.write_fmt(format_args!(
"{out} = info[({position} * rank_2) + {dim} + 1];\n"
)),
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-linalg/src/matmul/tests/cmma/test_cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ pub enum MatmulTest {
MSmallerThanN,
}

impl Into<MatmulTestCase> for MatmulTest {
fn into(self) -> MatmulTestCase {
match self {
impl From<MatmulTest> for MatmulTestCase {
fn from(val: MatmulTest) -> Self {
match val {
MatmulTest::SmallRound => MatmulTestCase {
m: 64,
k: 64,
Expand Down
Loading

0 comments on commit d509fbd

Please sign in to comment.