Skip to content

Commit

Permalink
Merge pull request #5 from tracel-ai/feat/literal
Browse files Browse the repository at this point in the history
Feat/literal
  • Loading branch information
nathanielsimard authored Jul 17, 2024
2 parents e0b0589 + a41801b commit 409ef0e
Show file tree
Hide file tree
Showing 28 changed files with 678 additions and 380 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {

#[cube]
fn gelu_scalar<F: Float>(x: F) -> F {
x * (F::new(1.0) + F::erf(x / F::sqrt(F::new(2.0)))) / F::new(2.0)
x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0
}

```
Expand Down
10 changes: 6 additions & 4 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ where
pub fn range_expand<F, S, E>(context: &mut CubeContext, start: S, end: E, unroll: bool, mut func: F)
where
F: FnMut(&mut CubeContext, ExpandElementTyped<UInt>),
S: Into<ExpandElement>,
E: Into<ExpandElement>,
S: Into<ExpandElementTyped<UInt>>,
E: Into<ExpandElementTyped<UInt>>,
{
let start: ExpandElement = start.into();
let end: ExpandElement = end.into();
let start: ExpandElementTyped<UInt> = start.into();
let end: ExpandElementTyped<UInt> = end.into();
let start = start.expand;
let end = end.expand;

if unroll {
let start = match start.deref() {
Expand Down
14 changes: 7 additions & 7 deletions crates/cubecl-core/src/frontend/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,23 @@ impl<C: CubePrimitive> Matrix<C> {
///
/// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
#[allow(unused_variables)]
pub fn new(ident: MatrixIdent, m: u8, n: u8, k: u8, layout: MatrixLayout) -> Self {
pub fn new(ident: MatrixIdent, m: u32, n: u32, k: u32, layout: MatrixLayout) -> Self {
Matrix { _c: PhantomData }
}

pub fn __expand_new(
context: &mut CubeContext,
ident: MatrixIdent,
m: u8,
n: u8,
k: u8,
m: ExpandElementTyped<UInt>,
n: ExpandElementTyped<UInt>,
k: ExpandElementTyped<UInt>,
layout: MatrixLayout,
) -> MatrixExpand {
let elem = context.create_matrix(ir::Matrix {
ident,
m,
n,
k,
m: m.constant().unwrap().as_u32() as u8,
n: n.constant().unwrap().as_u32() as u8,
k: k.constant().unwrap().as_u32() as u8,
elem: C::as_elem(),
layout,
});
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/frontend/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl<T> Comptime<T> {
unexpanded!()
}

/// Executes a closure on the comptime and returns a new comptime containing the value.
pub fn map<R, F: Fn(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
unexpanded!()
}
Expand Down
14 changes: 10 additions & 4 deletions crates/cubecl-core/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,21 @@ impl<C: CubePrimitive> ExpandElementTyped<Array<C>> {
) -> ExpandElementTyped<C> {
let factor = vectorization_factor.val;
let var = self.expand.clone();
let mut new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8));
let new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8));

if vectorization_factor.val == 1 {
let element = index::expand(context, self.clone(), 0u32);
let element = index::expand(context, self.clone(), ExpandElementTyped::from_lit(0u32));
assign::expand(context, element, new_var.clone());
} else {
for i in 0..factor {
let expand: Self = self.expand.clone().into();
let element = index::expand(context, expand, i);
new_var = index_assign::expand(context, new_var, i, element);
let element = index::expand(context, expand, ExpandElementTyped::from_lit(i));
index_assign::expand::<Array<C>>(
context,
new_var.clone().into(),
ExpandElementTyped::from_lit(i),
element,
);
}
}
new_var.into()
Expand Down
95 changes: 89 additions & 6 deletions crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::marker::PhantomData;

use super::{Bool, Numeric, UInt, Vectorized, F32, F64, I32, I64};
use crate::{
ir::{Operator, Variable, Vectorization},
prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher},
ir::{ConstantScalarValue, Elem, Item, Operator, Variable, Vectorization},
prelude::{index_assign, init_expand, CubeContext, KernelBuilder, KernelLauncher},
KernelSettings, Runtime,
};
use alloc::rc::Rc;

use super::{UInt, Vectorized};
use std::marker::PhantomData;

/// Types used in a cube function must implement this trait
///
Expand Down Expand Up @@ -124,6 +122,37 @@ pub struct ExpandElementTyped<T: CubeType> {
pub(crate) _type: PhantomData<T>,
}

macro_rules! from_const {
($lit:ty, $ty:ty) => {
impl From<$lit> for ExpandElementTyped<$ty> {
fn from(value: $lit) -> Self {
let variable: Variable = value.into();

ExpandElement::Plain(variable).into()
}
}
};
(val $($lit:ty),*) => {
$(
impl From<$lit> for ExpandElementTyped<UInt> {
fn from(value: $lit) -> Self {
let variable: Variable = value.val.into();

ExpandElement::Plain(variable).into()
}
}
)*
};
}

from_const!(u32, UInt);
from_const!(i64, I64);
from_const!(i32, I32);
from_const!(f64, F64);
from_const!(f32, F32);
from_const!(bool, Bool);
from_const!(val UInt, I32, I64, F32, F64);

pub trait ExpandElementBaseInit: CubeType {
fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement;
}
Expand Down Expand Up @@ -171,7 +200,25 @@ impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
}
}

impl<T: CubeType> ExpandElementTyped<T> {
/// Create an [ExpandElementTyped] from a value that is normaly a literal.
pub fn from_lit<L: Into<Variable>>(lit: L) -> Self {
let variable: Variable = lit.into();

ExpandElementTyped::new(ExpandElement::Plain(variable))
}

/// Get the [ConstantScalarValue] from the variable.
pub fn constant(&self) -> Option<ConstantScalarValue> {
match *self.expand {
Variable::ConstantScalar(val) => Some(val),
_ => None,
}
}
}

impl ExpandElement {
/// If the element can be mutated inplace, potentially reusing the register.
pub fn can_mut(&self) -> bool {
match self {
ExpandElement::Managed(var) => {
Expand Down Expand Up @@ -299,3 +346,39 @@ impl<T: Init> Init for Vec<T> {
self.into_iter().map(|e| e.init(context)).collect()
}
}

/// Create a constant element of the correct type during expansion.
pub(crate) fn __expand_new<C: Numeric>(
_context: &mut CubeContext,
val: ExpandElementTyped<C>,
elem: Elem,
) -> ExpandElementTyped<C> {
ExpandElement::Plain(elem.from_constant(*val.expand)).into()
}

/// Create a vectorized constant element of the correct type during expansion.
pub(crate) fn __expand_vectorized<C: Numeric>(
context: &mut CubeContext,
val: ExpandElementTyped<C>,
vectorization: UInt,
elem: Elem,
) -> ExpandElementTyped<C> {
if vectorization.val == 1 {
__expand_new(context, val, elem)
} else {
let new_var = context.create_local(Item::vectorized(elem, vectorization.val as u8));

for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() {
let element = elem.from_constant(*element.expand);

index_assign::expand::<C>(
context,
new_var.clone().into(),
ExpandElementTyped::from_lit(i),
ExpandElement::Plain(element).into(),
);
}

new_var.into()
}
}
8 changes: 5 additions & 3 deletions crates/cubecl-core/src/frontend/element/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ pub trait BoolOps {
fn new(value: bool) -> bool {
value
}
fn __expand_new(_context: &mut CubeContext, value: bool) -> ExpandElementTyped<bool> {
let var: ExpandElement = value.into();
var.into()
fn __expand_new(
_context: &mut CubeContext,
value: ExpandElementTyped<bool>,
) -> ExpandElementTyped<bool> {
ExpandElement::Plain(Elem::Bool.from_constant(*value.expand)).into()
}
}

Expand Down
82 changes: 34 additions & 48 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use crate::frontend::{
};
use crate::ir::{ConstantScalarValue, Elem, FloatKind, Item, Variable, Vectorization};

use super::{init_expand_element, LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
use super::{
init_expand_element, LaunchArgExpand, ScalarArgSettings, UInt, Vectorized, __expand_new,
__expand_vectorized,
};
use crate::compute::{KernelBuilder, KernelLauncher};
use crate::prelude::index_assign;
use crate::{unexpanded, Runtime};
use crate::Runtime;

/// Floating point numbers. Used as input in float kernels
pub trait Float:
Expand All @@ -27,18 +29,35 @@ pub trait Float:
+ Ceil
+ Erf
+ Recip
+ core::ops::Index<UInt, Output = Self>
+ core::ops::IndexMut<UInt, Output = Self>
+ From<f32>
+ core::ops::Add<f32, Output = Self>
+ core::ops::Sub<f32, Output = Self>
+ core::ops::Mul<f32, Output = Self>
+ core::ops::Div<f32, Output = Self>
+ std::ops::AddAssign<f32>
+ std::ops::SubAssign<f32>
+ std::ops::MulAssign<f32>
+ std::ops::DivAssign<f32>
+ std::cmp::PartialOrd<f32>
+ std::cmp::PartialEq<f32>
{
fn new(val: f32) -> Self;
fn vectorized(val: f32, vectorization: UInt) -> Self;
fn vectorized_empty(vectorization: UInt) -> Self;
fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
fn __expand_new(
context: &mut CubeContext,
val: Self::ExpandType,
) -> <Self as CubeType>::ExpandType {
__expand_new(context, val, Self::as_elem())
}
fn __expand_vectorized(
context: &mut CubeContext,
val: f32,
val: Self::ExpandType,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType;
) -> <Self as CubeType>::ExpandType {
__expand_vectorized(context, val, vectorization, Self::as_elem())
}

fn __expand_vectorized_empty(
context: &mut CubeContext,
vectorization: UInt,
Expand Down Expand Up @@ -81,6 +100,12 @@ macro_rules! impl_float {
type Primitive = $primitive;
}

impl From<u32> for $type {
fn from(val: u32) -> Self {
$type::from_int(val)
}
}

impl ExpandElementBaseInit for $type {
fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
init_expand_element(context, elem)
Expand Down Expand Up @@ -110,37 +135,12 @@ macro_rules! impl_float {
Self::vectorized(0., vectorization)
}

fn __expand_new(
_context: &mut CubeContext,
val: f32,
) -> <Self as CubeType>::ExpandType {
Self::new(val).into_expand()
}

fn __expand_vectorized(
context: &mut CubeContext,
val: f32,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
if vectorization.val == 1 {
Self::__expand_new(context, val)
} else {
let mut new_var = context
.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() {
new_var = index_assign::expand(context, new_var, i, *element);
}

new_var.into()
}
}

fn __expand_vectorized_empty(
context: &mut CubeContext,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
if vectorization.val == 1 {
Self::__expand_new(context, 0.)
Self::__expand_new(context, ExpandElementTyped::from_lit(0.))
} else {
context
.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8))
Expand All @@ -149,20 +149,6 @@ macro_rules! impl_float {
}
}

impl core::ops::Index<UInt> for $type {
type Output = Self;

fn index(&self, _index: UInt) -> &Self::Output {
unexpanded!()
}
}

impl core::ops::IndexMut<UInt> for $type {
fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
unexpanded!()
}
}

impl LaunchArgExpand for $type {
fn expand(
builder: &mut KernelBuilder,
Expand Down
Loading

0 comments on commit 409ef0e

Please sign in to comment.