Skip to content

Commit

Permalink
Fix macro + cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Sep 10, 2024
1 parent ccd7299 commit 7a86f9a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
16 changes: 14 additions & 2 deletions crates/cubecl-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem::MaybeUninit;

use cubecl_core::{
ir::{Elem, FloatKind},
Feature, FeatureSet, Properties, Runtime,
Expand All @@ -21,6 +23,8 @@ pub struct CudaRuntime;
// Single static runtime instance for the CUDA backend, parameterized over the device
// type, the server type and a mutex-based channel wrapping that server.
static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
ComputeRuntime::new();

// Alignment value shared by the memory-management options and the runtime
// `Properties::memory_offset_alignment`, so the two cannot drift apart.
// NOTE(review): presumably expressed in bytes — confirm against the allocator.
const MEMORY_OFFSET_ALIGNMENT: u32 = 32;

// Server type used by this runtime: a CUDA server backed by dynamic memory
// management over `CudaStorage`.
type Server = CudaServer<DynamicMemoryManagement<CudaStorage>>;

impl Runtime for CudaRuntime {
Expand Down Expand Up @@ -50,8 +54,16 @@ impl Runtime for CudaRuntime {
cudarc::driver::result::stream::StreamKind::NonBlocking,
)
.unwrap();
// Query the total memory of the device; the value is used below to size the
// memory manager's maximum chunk relative to what the device actually has.
let max_memory = unsafe {
    let mut bytes = MaybeUninit::uninit();
    // SAFETY: `bytes` is a valid out-pointer for the duration of the call, and it
    // is only read after the driver reported success, i.e. after it has been
    // written. Previously the status was ignored, so a failed call would have
    // made `assume_init` read uninitialized memory (undefined behaviour).
    // NOTE(review): also relies on `device_ptr` being a valid device handle; it is
    // defined outside this excerpt — confirm.
    cudarc::driver::sys::lib()
        .cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr)
        .result()
        .expect("failed to query total CUDA device memory");
    bytes.assume_init()
};
let storage = CudaStorage::new(stream);
let options = DynamicMemoryManagementOptions::preset(2048 + 512 * 1024 * 1024, 32);
let options = DynamicMemoryManagementOptions::preset(
max_memory / 4, // Max chunk size is max_memory / 4
MEMORY_OFFSET_ALIGNMENT as usize,
);
let memory_management = DynamicMemoryManagement::new(storage, options);
CudaContext::new(memory_management, stream, ctx, arch)
}
Expand All @@ -65,7 +77,7 @@ impl Runtime for CudaRuntime {
MutexComputeChannel::new(server),
features,
Properties {
memory_offset_alignment: 4,
memory_offset_alignment: MEMORY_OFFSET_ALIGNMENT,
},
)
})
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-macros/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ impl Context {
.scopes
.iter()
.enumerate()
.rev()
.flat_map(|(i, scope)| scope.variables.iter().map(move |it| (i, it)))
.find(|(_, var)| &var.name == name && var.use_count.load(Ordering::Acquire) > 0)
.unwrap_or_else(|| {
Expand Down

0 comments on commit 7a86f9a

Please sign in to comment.