diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs
index 5895dbdd1..63016923a 100644
--- a/crates/cubecl-cuda/src/runtime.rs
+++ b/crates/cubecl-cuda/src/runtime.rs
@@ -1,3 +1,5 @@
+use std::mem::MaybeUninit;
+
 use cubecl_core::{
     ir::{Elem, FloatKind},
     Feature, FeatureSet, Properties, Runtime,
@@ -21,6 +23,8 @@ pub struct CudaRuntime;
 
 static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> = ComputeRuntime::new();
 
+const MEMORY_OFFSET_ALIGNMENT: u32 = 32;
+
 type Server = CudaServer<DynamicMemoryManagement<CudaStorage>>;
 
 impl Runtime for CudaRuntime {
@@ -50,8 +54,16 @@ impl Runtime for CudaRuntime {
             cudarc::driver::result::stream::StreamKind::NonBlocking,
         )
         .unwrap();
+        let max_memory = unsafe {
+            let mut bytes = MaybeUninit::uninit();
+            cudarc::driver::sys::lib().cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr);
+            bytes.assume_init()
+        };
         let storage = CudaStorage::new(stream);
-        let options = DynamicMemoryManagementOptions::preset(2048 + 512 * 1024 * 1024, 32);
+        let options = DynamicMemoryManagementOptions::preset(
+            max_memory / 4, // Max chunk size is max_memory / 4
+            MEMORY_OFFSET_ALIGNMENT as usize,
+        );
         let memory_management = DynamicMemoryManagement::new(storage, options);
         CudaContext::new(memory_management, stream, ctx, arch)
     }
@@ -65,7 +77,7 @@ impl Runtime for CudaRuntime {
                 MutexComputeChannel::new(server),
                 features,
                 Properties {
-                    memory_offset_alignment: 4,
+                    memory_offset_alignment: MEMORY_OFFSET_ALIGNMENT,
                 },
             )
         })
diff --git a/crates/cubecl-macros/src/scope.rs b/crates/cubecl-macros/src/scope.rs
index 24de9633f..63bb2e878 100644
--- a/crates/cubecl-macros/src/scope.rs
+++ b/crates/cubecl-macros/src/scope.rs
@@ -172,7 +172,6 @@ impl Context {
             .scopes
             .iter()
             .enumerate()
-            .rev()
             .flat_map(|(i, scope)| scope.variables.iter().map(move |it| (i, it)))
             .find(|(_, var)| &var.name == name && var.use_count.load(Ordering::Acquire) > 0)
             .unwrap_or_else(|| {