Skip to content

Commit

Permalink
Added normalize for scalars on wgpu, for cuda compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
RianGoossens committed Sep 10, 2024
1 parent 7a86f9a commit a0b89eb
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,17 @@ for (var {i}: u32 = {start}; {i} {cmp} {end}; {increment}) {{
)),
Instruction::Negate { input, out } => f.write_fmt(format_args!("{out} = -{input};\n")),
Instruction::Normalize { input, out } => {
f.write_fmt(format_args!("{out} = normalize({input});\n"))
if input.item().vectorization_factor() == 1 {
// We need a check for vectorization factor 1 here, for compatibility with cuda.
// You can almost use sign here, however that does not correctly handle the case for x == 0.0.
// Therefore we use normalize with vec2, as there is no way to use a NaN literal in wgsl.
let vec2_type = Item::Vec2(out.elem());
f.write_fmt(format_args!(
"{out} = normalize({vec2_type}({input}, 0.0)).x;\n"
))
} else {
f.write_fmt(format_args!("{out} = normalize({input});\n"))
}
}
}
}
Expand Down

0 comments on commit a0b89eb

Please sign in to comment.