From 7e02e76704ee5d5ac7bf502d4955999db1c37c86 Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Tue, 7 Jan 2025 11:31:41 -0500 Subject: [PATCH] Fix magnitude cuda (#398) * add missing fmt_left * add comments --- crates/cubecl-core/src/frontend/element/base.rs | 2 ++ crates/cubecl-core/src/frontend/operation/base.rs | 4 +--- crates/cubecl-cpp/src/shared/instruction.rs | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index b7fbc858..81e82e36 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -31,6 +31,8 @@ pub trait CubeType { } /// Trait useful for cube types that are also used with comptime. +/// +/// This is used to set a variable as mutable. (Need to be fixed or at least renamed.) pub trait IntoRuntime: CubeType + Sized { /// Make sure a type is actually expanded into its runtime [expand type](CubeType::ExpandType). fn runtime(self) -> Self { diff --git a/crates/cubecl-core/src/frontend/operation/base.rs b/crates/cubecl-core/src/frontend/operation/base.rs index 2c1cf308..915f3a08 100644 --- a/crates/cubecl-core/src/frontend/operation/base.rs +++ b/crates/cubecl-core/src/frontend/operation/base.rs @@ -187,15 +187,13 @@ where if input.can_mut() { return input; } - let input_var: Variable = *input; let item = input.item; - let out = context.create_local_mut(item); + let out = context.create_local_mut(item); // TODO: The mut is safe, but unecessary if the variable is immutable. let out_var = *out; let op = func(input_var); - context.register(Instruction::new(op, out_var)); out diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 2c58b2b1..4c765558 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -764,6 +764,7 @@ impl Magnitude { writeln!(f, "{mag} += {input_i} * {input_i};")?; } + let out = out.fmt_left(); write!(f, "{out} = ")?; Sqrt::format_unary(f, &mag, elem)?; f.write_str(";\n")