Skip to content

Commit

Permalink
Fix magnitude cuda (#398)
Browse files Browse the repository at this point in the history
* add missing fmt_left

* add comments
  • Loading branch information
maxtremblay authored Jan 7, 2025
1 parent 169ac37 commit 7e02e76
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub trait CubeType {
}

/// Trait useful for cube types that are also used with comptime.
///
/// This is used to set a variable as mutable. (Need to be fixed or at least renamed.)
pub trait IntoRuntime: CubeType + Sized {
/// Make sure a type is actually expanded into its runtime [expand type](CubeType::ExpandType).
fn runtime(self) -> Self {
Expand Down
4 changes: 1 addition & 3 deletions crates/cubecl-core/src/frontend/operation/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,13 @@ where
if input.can_mut() {
return input;
}

let input_var: Variable = *input;
let item = input.item;

let out = context.create_local_mut(item);
let out = context.create_local_mut(item); // TODO: The mut is safe, but unecessary if the variable is immutable.
let out_var = *out;

let op = func(input_var);

context.register(Instruction::new(op, out_var));

out
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ impl<D: Dialect> Magnitude<D> {
writeln!(f, "{mag} += {input_i} * {input_i};")?;
}

let out = out.fmt_left();
write!(f, "{out} = ")?;
Sqrt::format_unary(f, &mag, elem)?;
f.write_str(";\n")
Expand Down

0 comments on commit 7e02e76

Please sign in to comment.