diff --git a/crates/cubecl-core/src/frontend/mod.rs b/crates/cubecl-core/src/frontend/mod.rs index f7f29c4d4..b2f11c85e 100644 --- a/crates/cubecl-core/src/frontend/mod.rs +++ b/crates/cubecl-core/src/frontend/mod.rs @@ -8,6 +8,7 @@ mod context; mod element; mod indexation; mod operation; +mod sequence; mod subcube; mod topology; @@ -15,5 +16,6 @@ pub use comptime::*; pub use context::*; pub use element::*; pub use operation::*; +pub use sequence::*; pub use subcube::*; pub use topology::*; diff --git a/crates/cubecl-core/src/frontend/operation/assignation.rs b/crates/cubecl-core/src/frontend/operation/assignation.rs index 1594320bb..0f8e05cb1 100644 --- a/crates/cubecl-core/src/frontend/operation/assignation.rs +++ b/crates/cubecl-core/src/frontend/operation/assignation.rs @@ -207,85 +207,69 @@ pub mod index { } pub mod add_assign_array_op { - use crate::prelude::array_assign_binary_op_expand; - use self::ir::Operator; - use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand< - Array: Into, - Index: Into, - Value: Into, - >( + pub fn expand>( context: &mut CubeContext, - array: Array, - index: Index, - value: Value, - ) { + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { array_assign_binary_op_expand(context, array, index, value, Operator::Add); } } pub mod sub_assign_array_op { - use crate::prelude::array_assign_binary_op_expand; - use self::ir::Operator; - use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand< - Array: Into, - Index: Into, - Value: Into, - >( + pub fn expand>( context: &mut CubeContext, - array: Array, - index: Index, - value: Value, - ) { + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { array_assign_binary_op_expand(context, array, index, value, Operator::Sub); } } pub mod mul_assign_array_op { - use crate::prelude::array_assign_binary_op_expand; - use self::ir::Operator; - use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand< - Array: Into, - Index: Into, - Value: Into, - >( + pub fn expand>( context: &mut CubeContext, - array: Array, - index: Index, - value: Value, - ) { + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { array_assign_binary_op_expand(context, array, index, value, Operator::Mul); } } pub mod div_assign_array_op { - use crate::prelude::array_assign_binary_op_expand; - use self::ir::Operator; - use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand< - Array: Into, - Index: Into, - Value: Into, - >( + pub fn expand>( context: &mut CubeContext, - array: Array, - index: Index, - value: Value, - ) { + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { array_assign_binary_op_expand(context, array, index, value, Operator::Div); } } diff --git a/crates/cubecl-core/src/frontend/operation/base.rs b/crates/cubecl-core/src/frontend/operation/base.rs index 4d0c70548..70d071894 100644 --- a/crates/cubecl-core/src/frontend/operation/base.rs +++ b/crates/cubecl-core/src/frontend/operation/base.rs @@ -1,5 +1,6 @@ use crate::frontend::{CubeContext, ExpandElement}; use crate::ir::{BinaryOperator, Elem, Item, Operator, UnaryOperator, Variable, Vectorization}; +use crate::prelude::{CubeType, ExpandElementTyped, UInt}; pub(crate) fn binary_expand( context: &mut CubeContext, @@ -205,17 +206,17 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization } pub fn array_assign_binary_op_expand< - Array: Into, - Index: Into, - Value: Into, + A: CubeType + core::ops::Index, F: Fn(BinaryOperator) -> Operator, >( context: &mut CubeContext, - array: Array, - index: Index, - value: Value, + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, func: F, -) { +) where + A::Output: CubeType + Sized, +{ let array: ExpandElement = array.into(); let index: ExpandElement = index.into(); let value: ExpandElement = value.into(); diff --git a/crates/cubecl-core/src/frontend/sequence.rs b/crates/cubecl-core/src/frontend/sequence.rs new file mode 100644 index 000000000..f285dd3af --- /dev/null +++ b/crates/cubecl-core/src/frontend/sequence.rs @@ -0,0 +1,133 @@ +use super::{indexation::Index, CubeContext, CubeType, Init}; +use crate::unexpanded; +use std::{cell::RefCell, rc::Rc}; + +/// A sequence of [cube types](CubeType) that is inlined during compilation. +/// +/// In other words, it allows you to group a dynamic amount of variables at compile time. +/// +/// All methods [push](Sequence::push), [index](Sequence::index) and +/// [into_iter](Sequence::into_iter) are executed _during_ compilation and don't add any overhead +/// on the generated kernel. +pub struct Sequence { + values: Vec, +} + +impl Default for Sequence { + fn default() -> Self { + Self::new() + } +} + +impl Sequence { + /// Create a new empty sequence. + pub fn new() -> Self { + Self { values: Vec::new() } + } + + /// Push a new value into the sequence. + pub fn push(&mut self, value: T) { + self.values.push(value); + } + + /// Get the variable at the given position in the sequence. + #[allow(unused_variables, clippy::should_implement_trait)] + pub fn index(&self, index: I) -> &T { + unexpanded!(); + } + + /// Expand function of [new](Self::new). + pub fn __expand_new(_context: &mut CubeContext) -> SequenceExpand { + SequenceExpand { + values: Rc::new(RefCell::new(Vec::new())), + } + } + + /// Expand function of [push](Self::push). + pub fn __expand_push( + context: &mut CubeContext, + expand: &mut SequenceExpand, + value: T::ExpandType, + ) { + expand.__expand_push_method(context, value) + } + + /// Expand function of [index](Self::index). + pub fn __expand_index( + context: &mut CubeContext, + expand: SequenceExpand, + index: I, + ) -> T::ExpandType { + expand.__expand_index_method(context, index) + } +} + +/// Expand type of [Sequence]. +pub struct SequenceExpand { + // We clone the expand type during the compilation phase, but for register reuse, not for + // copying data. To achieve the intended behavior, we have to share the same underlying values. + values: Rc>>, +} + +impl Init for SequenceExpand { + fn init(self, _context: &mut crate::prelude::CubeContext) -> Self { + self + } +} + +impl Clone for SequenceExpand { + fn clone(&self) -> Self { + Self { + values: self.values.clone(), + } + } +} + +impl IntoIterator for Sequence { + type Item = T; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.values.into_iter() + } +} + +impl IntoIterator for SequenceExpand { + type Item = T::ExpandType; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.values.take().into_iter() + } +} + +impl CubeType for Sequence { + type ExpandType = SequenceExpand; +} + +impl SequenceExpand { + /// Expand method of [push](Sequence::push). + pub fn __expand_push_method(&mut self, _context: &mut CubeContext, value: T::ExpandType) { + self.values.borrow_mut().push(value); + } + + /// Expand method of [index](Sequence::index). + pub fn __expand_index_method( + &self, + _context: &mut CubeContext, + index: I, + ) -> T::ExpandType { + let value = index.value(); + let index = match value { + crate::ir::Variable::ConstantScalar(value) => match value { + crate::ir::ConstantScalarValue::Int(val, _) => val as usize, + crate::ir::ConstantScalarValue::UInt(val) => val as usize, + _ => panic!("Only integer types are supported"), + }, + _ => panic!("Only constant are supported"), + }; + self.values.borrow()[index].clone() + } +} diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 005c5bb3c..c8f9faa9a 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -1,6 +1,7 @@ pub mod assign; pub mod cmma; pub mod launch; +pub mod sequence; pub mod slice; pub mod subcube; pub mod topology; @@ -17,5 +18,6 @@ macro_rules! testgen_all { cubecl_core::testgen_slice!(); cubecl_core::testgen_assign!(); cubecl_core::testgen_topology!(); + cubecl_core::testgen_sequence!(); }; } diff --git a/crates/cubecl-core/src/runtime_tests/sequence.rs b/crates/cubecl-core/src/runtime_tests/sequence.rs new file mode 100644 index 000000000..e0fe05573 --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/sequence.rs @@ -0,0 +1,84 @@ +use crate as cubecl; + +use cubecl::prelude::*; + +#[cube(launch)] +pub fn sequence_for_loop(output: &mut Array) { + if UNIT_POS != UInt::new(0) { + return; + } + + let mut sequence = Sequence::::new(); + sequence.push(F32::new(1.0)); + sequence.push(F32::new(4.0)); + + for value in sequence { + output[0] += value; + } +} + +#[cube(launch)] +pub fn sequence_index(output: &mut Array) { + if UNIT_POS != UInt::new(0) { + return; + } + + let mut sequence = Sequence::::new(); + sequence.push(F32::new(2.0)); + sequence.push(F32::new(4.0)); + + output[0] += *sequence.index(0); + output[0] += *Sequence::index(&sequence, 1); +} + +pub fn test_sequence_for_loop(client: ComputeClient) { + let handle = client.create(f32::as_bytes(&[0.0])); + + sequence_for_loop::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::default(), + ArrayArg::new(&handle, 2), + ); + + let actual = client.read(handle.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual[0], 5.0); +} + +pub fn test_sequence_index(client: ComputeClient) { + let handle = client.create(f32::as_bytes(&[0.0])); + + sequence_index::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::default(), + ArrayArg::new(&handle, 2), + ); + + let actual = client.read(handle.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual[0], 6.0); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_sequence { + () => { + use super::*; + + #[test] + fn test_sequence_for_loop() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::sequence::test_sequence_for_loop::(client); + } + + #[test] + fn test_sequence_index() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::sequence::test_sequence_index::(client); + } + }; +} diff --git a/crates/cubecl-macros/src/codegen_function/branch.rs b/crates/cubecl-macros/src/codegen_function/branch.rs index 66a45f98d..e74d8e0b5 100644 --- a/crates/cubecl-macros/src/codegen_function/branch.rs +++ b/crates/cubecl-macros/src/codegen_function/branch.rs @@ -80,6 +80,13 @@ pub(crate) fn codegen_for_loop( invalid_for_loop() } } + syn::Expr::Path(pat) => { + let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); + + quote::quote! { + for #i in #pat #block + } + } _ => invalid_for_loop(), } } diff --git a/crates/cubecl-macros/src/codegen_function/function.rs b/crates/cubecl-macros/src/codegen_function/function.rs index 51514d5de..9626f5548 100644 --- a/crates/cubecl-macros/src/codegen_function/function.rs +++ b/crates/cubecl-macros/src/codegen_function/function.rs @@ -64,7 +64,7 @@ pub(crate) fn codegen_closure( if let Some(ty) = ty { inputs.extend(quote::quote! { - #ident : #ty, + #ident: <#ty as CubeType>::ExpandType, }); } else { inputs.extend(quote::quote! { diff --git a/crates/cubecl-macros/src/codegen_function/operation.rs b/crates/cubecl-macros/src/codegen_function/operation.rs index bc41859dc..ede4aeeb0 100644 --- a/crates/cubecl-macros/src/codegen_function/operation.rs +++ b/crates/cubecl-macros/src/codegen_function/operation.rs @@ -266,6 +266,7 @@ pub(crate) fn codegen_unary( cubecl::frontend::not::expand(context, _inner) } }, + syn::UnOp::Deref(_) => inner, _ => todo!("Codegen: unsupported op {:?}", unary.op), }, CodegenKind::Expand,