From eedd617e56ed33ede27ff07718d7270f2863cd48 Mon Sep 17 00:00:00 2001 From: Torstein Grindvik Date: Fri, 22 Nov 2024 15:14:00 +0100 Subject: [PATCH] Keep the compiled kernel debug name, add docs Signed-off-by: Torstein Grindvik --- crates/cubecl-core/src/compute/kernel.rs | 47 ++++++++++++++++--- crates/cubecl-cuda/src/compute/server.rs | 2 +- .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 2 +- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index 4de6fb57a..c2d4b905c 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -6,7 +6,37 @@ use cubecl_runtime::ExecutionMode; /// A kernel, compiled in the target language pub struct CompiledKernel { - pub kernel_name: String, + /// The name of the kernel entrypoint. + /// For example + /// + /// ```text + /// #[cube(launch)] + /// fn gelu_array + /// ``` + /// + /// would have the entrypoint name "gelu_array". + pub entrypoint_name: String, + + /// A fully qualified debug name of the kernel. 
+ /// + /// For example + /// + /// ```text + /// #[cube(launch)] + /// fn gelu_array + /// ``` + /// + /// would have a debug name such as + /// + /// ```text + /// gelu::gelu_array::GeluArray< + /// cubecl_core::frontend::element::float::F32, + /// cubecl_cuda::runtime::CudaRuntime, + /// > + /// ``` + pub debug_name: Option<&'static str>, + /// Source code of the kernel pub source: String, /// In-memory representation of the kernel @@ -48,11 +78,13 @@ impl Display for CompiledKernel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("\n[START_KERNEL_COMPILATION]")?; - if self.kernel_name.len() <= 32 { - f.write_fmt(format_args!("\nname: {}", self.kernel_name))?; - } else { - let name = format_str(&self.kernel_name, &[('<', '>')], false); - f.write_fmt(format_args!("\nname: {name}"))?; + if let Some(name) = self.debug_name { + if name.len() <= 32 { + f.write_fmt(format_args!("\nname: {name}"))?; + } else { + let name = format_str(name, &[('<', '>')], false); + f.write_fmt(format_args!("\nname: {name}"))?; + } } f.write_fmt(format_args!( @@ -192,12 +224,13 @@ impl CubeTask for KernelTask { let shared_mem_bytes = lower_level_ir.shared_memory_size(); CompiledKernel { + entrypoint_name: kernel_name, + debug_name: Some(core::any::type_name::()), source: lower_level_ir.to_string(), repr: Some(lower_level_ir), cube_dim, shared_mem_bytes, debug_info: None, - kernel_name, } } diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 09ac5eed0..9687b77e5 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -374,7 +374,7 @@ impl CudaContext { cudarc::nvrtc::result::get_ptx(program).unwrap() }; - let func_name = CString::new(kernel_compiled.kernel_name).unwrap(); + let func_name = CString::new(kernel_compiled.entrypoint_name).unwrap(); let func = unsafe { let module = cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _).unwrap(); 
diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 66d722e67..ce30c04ea 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -139,7 +139,7 @@ impl WgpuCompiler for WgslCompiler { label: None, layout: layout.as_ref(), module: &module, - entry_point: &kernel.kernel_name, + entry_point: &kernel.entrypoint_name, compilation_options: wgpu::PipelineCompilationOptions { zero_initialize_workgroup_memory: false, ..Default::default()