Skip to content

Commit

Permalink
Fix native vec types (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 3, 2024
1 parent b0fa39b commit b009bcc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 59 deletions.
43 changes: 5 additions & 38 deletions crates/cubecl-cuda/src/compiler/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,11 @@ impl Binary for IndexAssign {
let item_rhs = rhs.item();

let format_vec = |f: &mut Formatter<'_>, cast: bool| {
let is_vec_native = item_out.is_vec_native();
f.write_str("{\n")?;
let var = "broadcasted";
f.write_fmt(format_args!("{item_out} {var};\n"))?;
for i in 0..item_out.vectorization {
if is_vec_native {
let char = match i {
0 => 'x',
1 => 'y',
2 => 'z',
3 => 'w',
_ => panic!("Invalid"),
};
if cast {
f.write_fmt(format_args!(
"{var}.{char} = {}({});\n",
item_out.elem,
rhs.index(i)
))?;
} else {
f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?;
}
} else if cast {
if cast {
f.write_fmt(format_args!(
"{var}.i_{i} = {}({});\n",
item_out.elem,
Expand Down Expand Up @@ -254,29 +236,14 @@ impl Binary for Index {
let item_lhs = lhs.item();

let format_vec = |f: &mut Formatter<'_>| {
let is_vec_native = item_out.is_vec_native();
f.write_str("{\n")?;
let var = "broadcasted";
f.write_fmt(format_args!("{item_out} {var};\n"))?;
for i in 0..item_out.vectorization {
if is_vec_native {
let char = match i {
0 => 'x',
1 => 'y',
2 => 'z',
3 => 'w',
_ => panic!("Invalid"),
};
f.write_fmt(format_args!(
"{var}.{char} = {}({lhs}[{rhs}].i_{i});\n",
item_out.elem
))?;
} else {
f.write_fmt(format_args!(
"{var}.i_{i} = {}({lhs}[{rhs}].i_{i});\n",
item_out.elem
))?;
}
f.write_fmt(format_args!(
"{var}.i_{i} = {}({lhs}[{rhs}].i_{i});\n",
item_out.elem
))?;
}
f.write_fmt(format_args!("{out} = {var};\n"))?;
f.write_str("}")?;
Expand Down
24 changes: 7 additions & 17 deletions crates/cubecl-cuda/src/compiler/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ impl Display for Item {
return f.write_fmt(format_args!("{}", self.elem));
}

if self.is_vec_native() {
let elem = self.optimized().elem;
return f.write_fmt(format_args!("{elem}"));
}

return f.write_fmt(format_args!("{}_{}", self.elem, self.vectorization));
}
}
Expand Down Expand Up @@ -397,14 +392,19 @@ impl Display for IndexedVariable {
if self.optimized {
let item = self.var.item();
f.write_fmt(format_args!(
"(reinterpret_cast<{item}*>(&{var}))->i_{}",
"(reinterpret_cast<{item}&>({var})).i_{}",
self.index
))
} else {
f.write_fmt(format_args!("{var}.i_{}", self.index))
}
} else {
f.write_fmt(format_args!("{var}"))
if self.optimized {
let item = self.var.item();
f.write_fmt(format_args!("reinterpret_cast<{item}&>({var})"))
} else {
f.write_fmt(format_args!("{var}"))
}
}
}
}
Expand Down Expand Up @@ -438,16 +438,6 @@ impl Item {
matches!(self.elem, Elem::F162 | Elem::BF162)
}

/// Reports whether this item's element type and vectorization factor form
/// one of the combinations treated as a "native" vector type.
///
/// Returns `true` for `F16` or `BF16` with a vectorization of 2, and for
/// `F162` or `BF162` with a vectorization of 1. Every other combination
/// returns `false`.
///
/// NOTE(review): presumably these pairs correspond to the backend's built-in
/// packed 16-bit float types — confirm against the code generator.
pub fn is_vec_native(&self) -> bool {
match &self.elem {
// `F16` / `BF16`: native only at a vectorization of exactly 2.
Elem::F16 => self.vectorization == 2,
Elem::BF16 => self.vectorization == 2,
// `F162` / `BF162`: native only at a vectorization of exactly 1.
Elem::F162 => self.vectorization == 1,
Elem::BF162 => self.vectorization == 1,
// No other element type is ever considered native.
_ => false,
}
}

pub fn optimized(&self) -> Item {
if self.vectorization == 1 {
return *self;
Expand Down
4 changes: 0 additions & 4 deletions crates/cubecl-cuda/src/compiler/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ impl Display for ComputeKernel {
f.write_str("typedef unsigned int uint;\n")?;

for item in self.items.iter() {
if item.is_vec_native() {
continue;
}

let elem = item.elem;
let size = item.vectorization;
let alignment = elem.size() * size;
Expand Down

0 comments on commit b009bcc

Please sign in to comment.