Skip to content

Commit

Permalink
Add compile-time sequence (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jul 25, 2024
1 parent df8ef81 commit f42e605
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 56 deletions.
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ mod context;
mod element;
mod indexation;
mod operation;
mod sequence;
mod subcube;
mod topology;

pub use comptime::*;
pub use context::*;
pub use element::*;
pub use operation::*;
pub use sequence::*;
pub use subcube::*;
pub use topology::*;
80 changes: 32 additions & 48 deletions crates/cubecl-core/src/frontend/operation/assignation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,85 +207,69 @@ pub mod index {
}

pub mod add_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;
use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
pub fn expand<A: CubeType + core::ops::Index<UInt>>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array: ExpandElementTyped<A>,
index: ExpandElementTyped<UInt>,
value: ExpandElementTyped<A::Output>,
) where
A::Output: CubeType + Sized,
{
array_assign_binary_op_expand(context, array, index, value, Operator::Add);
}
}

pub mod sub_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;
use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
pub fn expand<A: CubeType + core::ops::Index<UInt>>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array: ExpandElementTyped<A>,
index: ExpandElementTyped<UInt>,
value: ExpandElementTyped<A::Output>,
) where
A::Output: CubeType + Sized,
{
array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
}
}

pub mod mul_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;
use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
pub fn expand<A: CubeType + core::ops::Index<UInt>>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array: ExpandElementTyped<A>,
index: ExpandElementTyped<UInt>,
value: ExpandElementTyped<A::Output>,
) where
A::Output: CubeType + Sized,
{
array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
}
}

pub mod div_assign_array_op {
use crate::prelude::array_assign_binary_op_expand;

use self::ir::Operator;

use super::*;
use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};

pub fn expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
>(
pub fn expand<A: CubeType + core::ops::Index<UInt>>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
) {
array: ExpandElementTyped<A>,
index: ExpandElementTyped<UInt>,
value: ExpandElementTyped<A::Output>,
) where
A::Output: CubeType + Sized,
{
array_assign_binary_op_expand(context, array, index, value, Operator::Div);
}
}
Expand Down
15 changes: 8 additions & 7 deletions crates/cubecl-core/src/frontend/operation/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::frontend::{CubeContext, ExpandElement};
use crate::ir::{BinaryOperator, Elem, Item, Operator, UnaryOperator, Variable, Vectorization};
use crate::prelude::{CubeType, ExpandElementTyped, UInt};

pub(crate) fn binary_expand<F>(
context: &mut CubeContext,
Expand Down Expand Up @@ -205,17 +206,17 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization
}

pub fn array_assign_binary_op_expand<
Array: Into<ExpandElement>,
Index: Into<ExpandElement>,
Value: Into<ExpandElement>,
A: CubeType + core::ops::Index<UInt>,
F: Fn(BinaryOperator) -> Operator,
>(
context: &mut CubeContext,
array: Array,
index: Index,
value: Value,
array: ExpandElementTyped<A>,
index: ExpandElementTyped<UInt>,
value: ExpandElementTyped<A::Output>,
func: F,
) {
) where
A::Output: CubeType + Sized,
{
let array: ExpandElement = array.into();
let index: ExpandElement = index.into();
let value: ExpandElement = value.into();
Expand Down
133 changes: 133 additions & 0 deletions crates/cubecl-core/src/frontend/sequence.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use super::{indexation::Index, CubeContext, CubeType, Init};
use crate::unexpanded;
use std::{cell::RefCell, rc::Rc};

/// A compile-time list of [cube types](CubeType) whose content is inlined during compilation.
///
/// It lets you group a number of variables that is only decided while the kernel is being
/// compiled.
///
/// The [push](Sequence::push), [index](Sequence::index) and
/// [into_iter](Sequence::into_iter) methods are all executed _during_ compilation and add no
/// overhead to the generated kernel.
pub struct Sequence<T>
where
    T: CubeType,
{
    // Values pushed so far, in insertion order.
    values: Vec<T>,
}

impl<T: CubeType> Default for Sequence<T> {
fn default() -> Self {
Self::new()
}
}

impl<T: CubeType> Sequence<T> {
/// Create a new empty sequence.
pub fn new() -> Self {
Self { values: Vec::new() }
}

/// Push a new value into the sequence.
pub fn push(&mut self, value: T) {
self.values.push(value);
}

/// Get the variable at the given position in the sequence.
#[allow(unused_variables, clippy::should_implement_trait)]
pub fn index<I: Index>(&self, index: I) -> &T {
unexpanded!();
}

/// Expand function of [new](Self::new).
pub fn __expand_new(_context: &mut CubeContext) -> SequenceExpand<T> {
SequenceExpand {
values: Rc::new(RefCell::new(Vec::new())),
}
}

/// Expand function of [push](Self::push).
pub fn __expand_push(
context: &mut CubeContext,
expand: &mut SequenceExpand<T>,
value: T::ExpandType,
) {
expand.__expand_push_method(context, value)
}

/// Expand function of [index](Self::index).
pub fn __expand_index<I: Index>(
context: &mut CubeContext,
expand: SequenceExpand<T>,
index: I,
) -> T::ExpandType {
expand.__expand_index_method(context, index)
}
}

/// Expand type of [Sequence].
pub struct SequenceExpand<T>
where
    T: CubeType,
{
    // The expand type is cloned during the compilation phase for register reuse, not to copy
    // data, so every clone has to point at the same underlying values.
    values: Rc<RefCell<Vec<T::ExpandType>>>,
}

impl<T> Init for SequenceExpand<T>
where
    T: CubeType,
{
    /// Returns the sequence unchanged; the context is not used.
    fn init(self, _ctx: &mut crate::prelude::CubeContext) -> Self {
        self
    }
}

impl<T> Clone for SequenceExpand<T>
where
    T: CubeType,
{
    /// Clones the reference-counted handle only: the result shares its values with `self`.
    fn clone(&self) -> Self {
        let values = Rc::clone(&self.values);
        Self { values }
    }
}

impl<T> IntoIterator for Sequence<T>
where
    T: CubeType,
{
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    /// Consumes the sequence and yields its values in the order they were pushed.
    fn into_iter(self) -> Self::IntoIter {
        let Self { values } = self;
        values.into_iter()
    }
}

impl<T: CubeType> IntoIterator for SequenceExpand<T> {
    type Item = T::ExpandType;

    type IntoIter = <Vec<T::ExpandType> as IntoIterator>::IntoIter;

    /// Yields the expand values in the order they were pushed.
    ///
    /// The storage is shared between every clone of a `SequenceExpand`, so iterating must not
    /// empty it: otherwise looping over one clone would silently drain all the others.
    ///
    /// # Panics
    ///
    /// Panics if the shared storage is mutably borrowed while other clones are still alive.
    fn into_iter(self) -> Self::IntoIter {
        let values = match Rc::try_unwrap(self.values) {
            // Sole owner: move the values out without copying anything.
            Ok(cell) => cell.into_inner(),
            // Other clones still reference the storage: iterate over a copy and leave the
            // shared values untouched.
            Err(shared) => {
                let snapshot = shared.borrow().clone();
                snapshot
            }
        };
        values.into_iter()
    }
}

impl<T> CubeType for Sequence<T>
where
    T: CubeType,
{
    /// A sequence expands to [SequenceExpand] over the same element type.
    type ExpandType = SequenceExpand<T>;
}

impl<T: CubeType> SequenceExpand<T> {
    /// Expand method of [push](Sequence::push).
    ///
    /// Appends `value` to the storage shared by every clone of this sequence.
    pub fn __expand_push_method(&mut self, _context: &mut CubeContext, value: T::ExpandType) {
        self.values.borrow_mut().push(value);
    }

    /// Expand method of [index](Sequence::index).
    ///
    /// Returns a clone of the expand value stored at the position given by `index`.
    ///
    /// # Panics
    ///
    /// Panics when `index` is not a constant scalar, when the constant is not an integer,
    /// when it cannot be represented as a `usize` (for instance a negative value), or when
    /// it is past the end of the sequence.
    pub fn __expand_index_method<I: Index>(
        &self,
        _context: &mut CubeContext,
        index: I,
    ) -> T::ExpandType {
        let position = match index.value() {
            crate::ir::Variable::ConstantScalar(value) => match value {
                // A checked conversion avoids a negative constant wrapping into a huge index.
                crate::ir::ConstantScalarValue::Int(val, _) => usize::try_from(val)
                    .unwrap_or_else(|_| {
                        panic!("Sequence index {} cannot be converted to usize", val)
                    }),
                crate::ir::ConstantScalarValue::UInt(val) => usize::try_from(val)
                    .unwrap_or_else(|_| {
                        panic!("Sequence index {} cannot be converted to usize", val)
                    }),
                _ => panic!("Only integer types are supported"),
            },
            _ => panic!("Only constant are supported"),
        };

        let values = self.values.borrow();
        match values.get(position) {
            Some(value) => value.clone(),
            None => panic!(
                "Sequence index {} is out of bounds for a sequence of length {}",
                position,
                values.len()
            ),
        }
    }
}
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod assign;
pub mod cmma;
pub mod launch;
pub mod sequence;
pub mod slice;
pub mod subcube;
pub mod topology;
Expand All @@ -17,5 +18,6 @@ macro_rules! testgen_all {
cubecl_core::testgen_slice!();
cubecl_core::testgen_assign!();
cubecl_core::testgen_topology!();
cubecl_core::testgen_sequence!();
};
}
84 changes: 84 additions & 0 deletions crates/cubecl-core/src/runtime_tests/sequence.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use crate as cubecl;

use cubecl::prelude::*;

/// Kernel exercising a `for` loop over a [Sequence].
///
/// The unit whose `UNIT_POS` is 0 pushes 1.0 and 4.0 into a sequence, then adds every value
/// of the sequence to `output[0]`.
#[cube(launch)]
pub fn sequence_for_loop(output: &mut Array<F32>) {
// Every unit other than unit 0 exits without touching `output`.
if UNIT_POS != UInt::new(0) {
return;
}

let mut sequence = Sequence::<F32>::new();
sequence.push(F32::new(1.0));
sequence.push(F32::new(4.0));

// Accumulate each pushed value into the first output element.
for value in sequence {
output[0] += value;
}
}

/// Kernel exercising [Sequence::index] with constant positions.
///
/// The unit whose `UNIT_POS` is 0 pushes 2.0 and 4.0 into a sequence, then adds the values at
/// positions 0 and 1 to `output[0]`.
#[cube(launch)]
pub fn sequence_index(output: &mut Array<F32>) {
// Every unit other than unit 0 exits without touching `output`.
if UNIT_POS != UInt::new(0) {
return;
}

let mut sequence = Sequence::<F32>::new();
sequence.push(F32::new(2.0));
sequence.push(F32::new(4.0));

// Both call syntaxes are covered: method call, then associated-function call.
output[0] += *sequence.index(0);
output[0] += *Sequence::index(&sequence, 1);
}

/// Launches the `sequence_for_loop` kernel with a static cube count of (1, 1, 1) and checks
/// that the first output element equals 5.0, the sum of the two pushed values.
pub fn test_sequence_for_loop<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    // Buffer holding one f32 set to 0.0, which the kernel accumulates into.
    let output = client.create(f32::as_bytes(&[0.0]));
    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::default();

    // NOTE(review): the buffer is created from a single element but the array argument is
    // built with 2 — confirm what the second parameter of `ArrayArg::new` stands for.
    sequence_for_loop::launch::<R>(&client, cube_count, cube_dim, ArrayArg::new(&output, 2));

    let bytes = client.read(output.binding());
    let values = f32::from_bytes(&bytes);

    assert_eq!(values[0], 5.0);
}

/// Launches the `sequence_index` kernel with a static cube count of (1, 1, 1) and checks
/// that the first output element equals 6.0, the sum of the two indexed values.
pub fn test_sequence_index<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    // Buffer holding one f32 set to 0.0, which the kernel accumulates into.
    let output = client.create(f32::as_bytes(&[0.0]));
    let cube_count = CubeCount::Static(1, 1, 1);
    let cube_dim = CubeDim::default();

    // NOTE(review): the buffer is created from a single element but the array argument is
    // built with 2 — confirm what the second parameter of `ArrayArg::new` stands for.
    sequence_index::launch::<R>(&client, cube_count, cube_dim, ArrayArg::new(&output, 2));

    let bytes = client.read(output.binding());
    let values = f32::from_bytes(&bytes);

    assert_eq!(values[0], 6.0);
}

// Generates the sequence runtime tests. The invocation site must have `TestRuntime` reachable
// through `super::*`.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_sequence {
    () => {
        use super::*;

        // Runs the `for`-loop sequence test on a client built from the default device.
        #[test]
        fn test_sequence_for_loop() {
            cubecl_core::runtime_tests::sequence::test_sequence_for_loop::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }

        // Runs the indexing sequence test on a client built from the default device.
        #[test]
        fn test_sequence_index() {
            cubecl_core::runtime_tests::sequence::test_sequence_index::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }
    };
}
Loading

0 comments on commit f42e605

Please sign in to comment.