From 4d517c3b40a23ea6956492e0b5b8026f0e704898 Mon Sep 17 00:00:00 2001 From: Guillaume Charifi Date: Sat, 14 Sep 2024 20:45:34 +0200 Subject: [PATCH] Add support for numeric constants. (#112) --- Cargo.toml | 1 + crates/cubecl-core/Cargo.toml | 1 + .../cubecl-core/src/frontend/element/float.rs | 31 ++++- .../cubecl-core/src/frontend/element/int.rs | 11 +- .../src/frontend/element/numeric.rs | 3 + .../cubecl-core/src/frontend/element/uint.rs | 5 +- .../cubecl-core/tests/frontend/constants.rs | 114 ++++++++++++++++++ crates/cubecl-core/tests/frontend/mod.rs | 1 + 8 files changed, 164 insertions(+), 3 deletions(-) create mode 100644 crates/cubecl-core/tests/frontend/constants.rs diff --git a/Cargo.toml b/Cargo.toml index b6c2a6a4f..5afe6409e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ num-traits = { version = "0.2.19", default-features = false, features = [ darling = "0.20.10" ident_case = "1" +paste = "1.0.15" proc-macro2 = "1.0.86" quote = "1.0.36" syn = { version = "2", features = ["full", "extra-traits", "visit-mut"] } diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index e67d2d0a8..8d2912c58 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -28,6 +28,7 @@ cubecl-macros = { path = "../cubecl-macros", version = "0.2.0" } derive-new = { workspace = true } half = { workspace = true, features = ["bytemuck"] } num-traits = { workspace = true } +paste = { workspace = true } serde = { workspace = true } log = { workspace = true } diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 7de1d7b77..d95692f30 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -39,6 +39,19 @@ pub trait Float: + std::cmp::PartialOrd + std::cmp::PartialEq { + const DIGITS: u32; + const EPSILON: Self; + const INFINITY: Self; + const MANTISSA_DIGITS: u32; + const MAX_10_EXP: i32; + const MAX_EXP: i32; + const MIN_10_EXP: i32; + const MIN_EXP: i32; + const MIN_POSITIVE: Self; + const NAN: Self; + const NEG_INFINITY: Self; + const RADIX: u32; + fn new(val: f32) -> Self; fn vectorized(val: f32, vectorization: u32) -> Self; fn vectorized_empty(vectorization: u32) -> Self; @@ -88,7 +101,10 @@ macro_rules! impl_float { } } - impl Numeric for $primitive {} + impl Numeric for $primitive { + const MAX: Self = $primitive::MAX; + const MIN: Self = $primitive::MIN; + } impl Vectorized for $primitive { fn vectorization_factor(&self) -> u32 { @@ -107,6 +123,19 @@ macro_rules! impl_float { } impl Float for $primitive { + const DIGITS: u32 = $primitive::DIGITS; + const EPSILON: Self = $primitive::EPSILON; + const INFINITY: Self = $primitive::INFINITY; + const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS; + const MAX_10_EXP: i32 = $primitive::MAX_10_EXP; + const MAX_EXP: i32 = $primitive::MAX_EXP; + const MIN_10_EXP: i32 = $primitive::MIN_10_EXP; + const MIN_EXP: i32 = $primitive::MIN_EXP; + const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE; + const NAN: Self = $primitive::NAN; + const NEG_INFINITY: Self = $primitive::NEG_INFINITY; + const RADIX: u32 = $primitive::RADIX; + fn new(val: f32) -> Self { $new(val) } diff --git a/crates/cubecl-core/src/frontend/element/int.rs b/crates/cubecl-core/src/frontend/element/int.rs index a498b86d7..cdd572811 100644 --- a/crates/cubecl-core/src/frontend/element/int.rs +++ b/crates/cubecl-core/src/frontend/element/int.rs @@ -40,6 +40,8 @@ pub trait Int: + std::cmp::PartialOrd + std::cmp::PartialEq { + const BITS: u32; + fn new(val: i64) -> Self; fn vectorized(val: i64, vectorization: u32) -> Self; fn __expand_new(context: &mut CubeContext, val: i64) -> ::ExpandType { @@ -76,7 +78,10 @@ macro_rules! impl_int { } } - impl Numeric for $type {} + impl Numeric for $type { + const MAX: Self = $type::MAX; + const MIN: Self = $type::MIN; + } impl Vectorized for $type { fn vectorization_factor(&self) -> u32 { @@ -95,6 +100,8 @@ macro_rules! impl_int { } impl Int for $type { + const BITS: u32 = $type::BITS; + fn new(val: i64) -> Self { val as $type } @@ -120,6 +127,8 @@ impl_int!(i32, I32); impl_int!(i64, I64); impl Int for u32 { + const BITS: u32 = u32::BITS; + fn new(val: i64) -> Self { val as u32 } diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index 9c19b1c4f..dfb5d565b 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -49,6 +49,9 @@ pub trait Numeric: + std::cmp::PartialOrd + std::cmp::PartialEq { + const MAX: Self; + const MIN: Self; + /// Create a new constant numeric. /// /// Note: since this must work for both integer and float diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index 56283a8f4..d7ce3b21c 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -47,4 +47,7 @@ impl ScalarArgSettings for u32 { } } -impl Numeric for u32 {} +impl Numeric for u32 { + const MAX: Self = u32::MAX; + const MIN: Self = u32::MIN; +} diff --git a/crates/cubecl-core/tests/frontend/constants.rs b/crates/cubecl-core/tests/frontend/constants.rs new file mode 100644 index 000000000..d230a1a1a --- /dev/null +++ b/crates/cubecl-core/tests/frontend/constants.rs @@ -0,0 +1,114 @@ +use half::{bf16, f16}; +use paste::paste; + +use cubecl_core::{self as cubecl, prelude::*}; + +macro_rules! gen_cube { + ($trait:ident, [ $($constant:ident $(| $ret_type:ty)?),* ]) => { + $( + gen_cube!($trait, $constant, $($ret_type)?); + )* + }; + ($trait:ident, $constant:ident,) => { + gen_cube!($trait, $constant, T); + }; + ($trait:ident, $constant:ident, $ret_type:ty) => { + paste! { + gen_cube!([< $trait:lower _ $constant:lower >], $trait, $constant, $ret_type); + } + }; + ($func_name:ident, $trait:ident, $constant:ident, $ret_type:ty) => { + #[cube] + pub fn $func_name() -> $ret_type { + T::$constant + } + }; +} + +macro_rules! gen_tests { + ($trait:ident, [ $($type:ident),* ], $constants:tt) => { + $( + gen_tests!($trait, $type, $constants); + )* + }; + ($trait:ident, $type:ident, [ $($constant:ident $(| $ret_type:ty)?),* ]) => { + $( + gen_tests!($trait, $type, $constant, $($ret_type)?); + )* + }; + ($trait:ident, $type:ident, $constant:ident,) => { + gen_tests!($trait, $type, $constant, $type); + }; + ($trait:ident, $type:ident, $constant:ident, $ret_type:ty) => { + paste! { + gen_tests!([< cube_ $trait:lower _ $constant:lower _ $type _test >], [< $trait:lower _ $constant:lower >], $type, $constant, $ret_type); + } + }; + ($test_name:ident, $func_name:ident, $type:ty, $constant:ident, $ret_type:ty) => { + #[test] + fn $test_name() { + let mut context = CubeContext::root(); + $func_name::expand::<$type>(&mut context); + let scope = context.into_scope(); + + let mut scope1 = CubeContext::root().into_scope(); + let item = Item::new(<$ret_type>::as_elem()); + scope1.create_with_value(<$type>::$constant, item); + + assert_eq!( + format!("{:?}", scope.operations), + format!("{:?}", scope1.operations) + ); + } + }; +} + +gen_cube!(Numeric, [MAX, MIN]); +gen_cube!(Int, [BITS | u32]); +gen_cube!( + Float, + [ + DIGITS | u32, + EPSILON, + INFINITY, + MANTISSA_DIGITS | u32, + MAX_10_EXP | i32, + MAX_EXP | i32, + MIN_10_EXP | i32, + MIN_EXP | i32, + MIN_POSITIVE, + NAN, + NEG_INFINITY, + RADIX | u32 + ] +); + +mod tests { + use super::*; + use cubecl_core::{ + frontend::{CubeContext, CubePrimitive}, + ir::Item, + }; + use pretty_assertions::assert_eq; + + gen_tests!(Numeric, [bf16, f16, f32, f64, i32, i64, u32], [MAX, MIN]); + gen_tests!(Int, [i32, i64, u32], [BITS | u32]); + gen_tests!( + Float, + [bf16, f16, f32, f64], + [ + DIGITS | u32, + EPSILON, + INFINITY, + MANTISSA_DIGITS | u32, + MAX_10_EXP | i32, + MAX_EXP | i32, + MIN_10_EXP | i32, + MIN_EXP | i32, + MIN_POSITIVE, + NAN, + NEG_INFINITY, + RADIX | u32 + ] + ); +} diff --git a/crates/cubecl-core/tests/frontend/mod.rs b/crates/cubecl-core/tests/frontend/mod.rs index d5743ad99..8d4c0b161 100644 --- a/crates/cubecl-core/tests/frontend/mod.rs +++ b/crates/cubecl-core/tests/frontend/mod.rs @@ -3,6 +3,7 @@ mod assign; mod cast_elem; mod cast_kind; mod comptime; +mod constants; mod cube_trait; mod for_loop; mod function_call;