Skip to content

Commit

Permalink
Add support for numeric constants. (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
booti386 authored Sep 14, 2024
1 parent a8cbd14 commit 4d517c3
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ num-traits = { version = "0.2.19", default-features = false, features = [

darling = "0.20.10"
ident_case = "1"
paste = "1.0.15"
proc-macro2 = "1.0.86"
quote = "1.0.36"
syn = { version = "2", features = ["full", "extra-traits", "visit-mut"] }
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ cubecl-macros = { path = "../cubecl-macros", version = "0.2.0" }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
num-traits = { workspace = true }
paste = { workspace = true }
serde = { workspace = true }

log = { workspace = true }
Expand Down
31 changes: 30 additions & 1 deletion crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ pub trait Float:
+ std::cmp::PartialOrd
+ std::cmp::PartialEq
{
const DIGITS: u32;
const EPSILON: Self;
const INFINITY: Self;
const MANTISSA_DIGITS: u32;
const MAX_10_EXP: i32;
const MAX_EXP: i32;
const MIN_10_EXP: i32;
const MIN_EXP: i32;
const MIN_POSITIVE: Self;
const NAN: Self;
const NEG_INFINITY: Self;
const RADIX: u32;

fn new(val: f32) -> Self;
fn vectorized(val: f32, vectorization: u32) -> Self;
fn vectorized_empty(vectorization: u32) -> Self;
Expand Down Expand Up @@ -88,7 +101,10 @@ macro_rules! impl_float {
}
}

impl Numeric for $primitive {}
impl Numeric for $primitive {
const MAX: Self = $primitive::MAX;
const MIN: Self = $primitive::MIN;
}

impl Vectorized for $primitive {
fn vectorization_factor(&self) -> u32 {
Expand All @@ -107,6 +123,19 @@ macro_rules! impl_float {
}

impl Float for $primitive {
const DIGITS: u32 = $primitive::DIGITS;
const EPSILON: Self = $primitive::EPSILON;
const INFINITY: Self = $primitive::INFINITY;
const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
const MAX_EXP: i32 = $primitive::MAX_EXP;
const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
const MIN_EXP: i32 = $primitive::MIN_EXP;
const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
const NAN: Self = $primitive::NAN;
const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
const RADIX: u32 = $primitive::RADIX;

fn new(val: f32) -> Self {
$new(val)
}
Expand Down
11 changes: 10 additions & 1 deletion crates/cubecl-core/src/frontend/element/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pub trait Int:
+ std::cmp::PartialOrd
+ std::cmp::PartialEq
{
const BITS: u32;

fn new(val: i64) -> Self;
fn vectorized(val: i64, vectorization: u32) -> Self;
fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
Expand Down Expand Up @@ -76,7 +78,10 @@ macro_rules! impl_int {
}
}

impl Numeric for $type {}
impl Numeric for $type {
const MAX: Self = $type::MAX;
const MIN: Self = $type::MIN;
}

impl Vectorized for $type {
fn vectorization_factor(&self) -> u32 {
Expand All @@ -95,6 +100,8 @@ macro_rules! impl_int {
}

impl Int for $type {
const BITS: u32 = $type::BITS;

fn new(val: i64) -> Self {
val as $type
}
Expand All @@ -120,6 +127,8 @@ impl_int!(i32, I32);
impl_int!(i64, I64);

impl Int for u32 {
const BITS: u32 = u32::BITS;

fn new(val: i64) -> Self {
val as u32
}
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-core/src/frontend/element/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub trait Numeric:
+ std::cmp::PartialOrd
+ std::cmp::PartialEq
{
const MAX: Self;
const MIN: Self;

/// Create a new constant numeric.
///
/// Note: since this must work for both integer and float
Expand Down
5 changes: 4 additions & 1 deletion crates/cubecl-core/src/frontend/element/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,7 @@ impl ScalarArgSettings for u32 {
}
}

impl Numeric for u32 {}
impl Numeric for u32 {
const MAX: Self = u32::MAX;
const MIN: Self = u32::MIN;
}
114 changes: 114 additions & 0 deletions crates/cubecl-core/tests/frontend/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use half::{bf16, f16};
use paste::paste;

use cubecl_core::{self as cubecl, prelude::*};

macro_rules! gen_cube {
($trait:ident, [ $($constant:ident $(| $ret_type:ty)?),* ]) => {
$(
gen_cube!($trait, $constant, $($ret_type)?);
)*
};
($trait:ident, $constant:ident,) => {
gen_cube!($trait, $constant, T);
};
($trait:ident, $constant:ident, $ret_type:ty) => {
paste! {
gen_cube!([< $trait:lower _ $constant:lower >], $trait, $constant, $ret_type);
}
};
($func_name:ident, $trait:ident, $constant:ident, $ret_type:ty) => {
#[cube]
pub fn $func_name<T: $trait>() -> $ret_type {
T::$constant
}
};
}

macro_rules! gen_tests {
($trait:ident, [ $($type:ident),* ], $constants:tt) => {
$(
gen_tests!($trait, $type, $constants);
)*
};
($trait:ident, $type:ident, [ $($constant:ident $(| $ret_type:ty)?),* ]) => {
$(
gen_tests!($trait, $type, $constant, $($ret_type)?);
)*
};
($trait:ident, $type:ident, $constant:ident,) => {
gen_tests!($trait, $type, $constant, $type);
};
($trait:ident, $type:ident, $constant:ident, $ret_type:ty) => {
paste! {
gen_tests!([< cube_ $trait:lower _ $constant:lower _ $type _test >], [< $trait:lower _ $constant:lower >], $type, $constant, $ret_type);
}
};
($test_name:ident, $func_name:ident, $type:ty, $constant:ident, $ret_type:ty) => {
#[test]
fn $test_name() {
let mut context = CubeContext::root();
$func_name::expand::<$type>(&mut context);
let scope = context.into_scope();

let mut scope1 = CubeContext::root().into_scope();
let item = Item::new(<$ret_type>::as_elem());
scope1.create_with_value(<$type>::$constant, item);

assert_eq!(
format!("{:?}", scope.operations),
format!("{:?}", scope1.operations)
);
}
};
}

gen_cube!(Numeric, [MAX, MIN]);
gen_cube!(Int, [BITS | u32]);
gen_cube!(
Float,
[
DIGITS | u32,
EPSILON,
INFINITY,
MANTISSA_DIGITS | u32,
MAX_10_EXP | i32,
MAX_EXP | i32,
MIN_10_EXP | i32,
MIN_EXP | i32,
MIN_POSITIVE,
NAN,
NEG_INFINITY,
RADIX | u32
]
);

mod tests {
use super::*;
use cubecl_core::{
frontend::{CubeContext, CubePrimitive},
ir::Item,
};
use pretty_assertions::assert_eq;

gen_tests!(Numeric, [bf16, f16, f32, f64, i32, i64, u32], [MAX, MIN]);
gen_tests!(Int, [i32, i64, u32], [BITS | u32]);
gen_tests!(
Float,
[bf16, f16, f32, f64],
[
DIGITS | u32,
EPSILON,
INFINITY,
MANTISSA_DIGITS | u32,
MAX_10_EXP | i32,
MAX_EXP | i32,
MIN_10_EXP | i32,
MIN_EXP | i32,
MIN_POSITIVE,
NAN,
NEG_INFINITY,
RADIX | u32
]
);
}
1 change: 1 addition & 0 deletions crates/cubecl-core/tests/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod assign;
mod cast_elem;
mod cast_kind;
mod comptime;
mod constants;
mod cube_trait;
mod for_loop;
mod function_call;
Expand Down

0 comments on commit 4d517c3

Please sign in to comment.