From a0b89eb8b032abc4c13428dccbd74ea0ffeeb2d6 Mon Sep 17 00:00:00 2001 From: Rian Goossens Date: Wed, 11 Sep 2024 01:07:17 +0200 Subject: [PATCH] Added normalize for scalars on wgpu, for cuda compatibility --- crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 005b6b6f3..80cdebc3a 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -689,7 +689,17 @@ for (var {i}: u32 = {start}; {i} {cmp} {end}; {increment}) {{ )), Instruction::Negate { input, out } => f.write_fmt(format_args!("{out} = -{input};\n")), Instruction::Normalize { input, out } => { - f.write_fmt(format_args!("{out} = normalize({input});\n")) + if input.item().vectorization_factor() == 1 { + // We need a check for vectorization factor 1 here, for compatibility with cuda. + // You can almost use sign here, however that does not correctly handle the case for x == 0.0. + // Therefore we use normalize with vec2, as there is no way to use a NaN literal in wgsl. + let vec2_type = Item::Vec2(out.elem()); + f.write_fmt(format_args!( + "{out} = normalize({vec2_type}({input}, 0.0)).x;\n" + )) + } else { + f.write_fmt(format_args!("{out} = normalize({input});\n")) + } } } }