Skip to content

Commit

Permalink
Fix atomics on cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Sep 12, 2024
1 parent 4af5984 commit 10e838a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
3 changes: 1 addition & 2 deletions crates/cubecl-core/src/ir/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ impl Scope {
pub fn process(&mut self) -> ScopeProcessing {
self.undeclared += self.locals.len() as u16;

let mut variables = Vec::new();
core::mem::swap(&mut self.locals, &mut variables);
let mut variables = core::mem::take(&mut self.locals);

for var in self.matrices.drain(..) {
variables.push(var);
Expand Down
9 changes: 6 additions & 3 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,10 @@ impl CudaCompiler {
}),
gpu::Operator::Index(op) => {
if let ExecutionMode::Checked = self.strategy {
if has_length(&op.lhs) {
// Since atomics must be declared inline (for `wgpu` compatibility), we need to
// disable runtime checks for them. Otherwise the variable would be declared
// inside the `if` scope.
if has_length(&op.lhs) && !op.lhs.item().elem.is_atomic() {
self.compile_procedure(
instructions,
gpu::Procedure::CheckedIndex(gpu::CheckedIndex {
Expand Down Expand Up @@ -749,11 +752,11 @@ impl CudaCompiler {
gpu::IntKind::I64 => panic!("i64 isn't supported yet"),
},
gpu::Elem::AtomicInt(kind) => match kind {
gpu::IntKind::I32 => super::Elem::I32,
gpu::IntKind::I32 => super::Elem::Atomic(super::AtomicKind::I32),
gpu::IntKind::I64 => panic!("atomic<i64> isn't supported yet"),
},
gpu::Elem::UInt => super::Elem::U32,
gpu::Elem::AtomicUInt => super::Elem::U32,
gpu::Elem::AtomicUInt => super::Elem::Atomic(super::AtomicKind::U32),
gpu::Elem::Bool => super::Elem::Bool,
}
}
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-cuda/src/compiler/binary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Component, Variable};
use super::{Component, Elem, Variable};
use std::fmt::{Display, Formatter};

pub trait Binary {
Expand Down Expand Up @@ -259,6 +259,8 @@ impl Binary for Index {
f.write_fmt(format_args!("{out} = {}({lhs}[{rhs}]);\n", item_out.elem))?;
}
Ok(())
} else if let Elem::Atomic(inner) = item_out.elem {
f.write_fmt(format_args!("{inner}* {out} = &{lhs}[{rhs}];\n"))
} else {
f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n"))
}
Expand Down
19 changes: 19 additions & 0 deletions crates/cubecl-cuda/src/compiler/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ pub enum Elem {
I32,
U32,
Bool,
Atomic(AtomicKind),
}

/// Integer type wrapped by an atomic element; carried as the payload of `Elem::Atomic`.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum AtomicKind {
/// Signed 32-bit integer; its `Display` impl writes `int`.
I32,
/// Unsigned 32-bit integer; its `Display` impl writes `uint`.
U32,
}

impl Display for AtomicKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AtomicKind::I32 => f.write_str("int"),
AtomicKind::U32 => f.write_str("uint"),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
Expand All @@ -33,6 +49,7 @@ impl Display for Elem {
Elem::I32 => f.write_str("int"),
Elem::U32 => f.write_str("uint"),
Elem::Bool => f.write_str("bool"),
Elem::Atomic(inner) => inner.fmt(f),
}
}
}
Expand Down Expand Up @@ -470,6 +487,8 @@ impl Elem {
Self::I32 => core::mem::size_of::<i32>(),
Self::U32 => core::mem::size_of::<u32>(),
Self::Bool => core::mem::size_of::<bool>(),
Self::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
Self::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
}
}
}

0 comments on commit 10e838a

Please sign in to comment.