Skip to content

Commit

Permalink
Keep the compiled kernel debug name, add docs
Browse files Browse the repository at this point in the history
Signed-off-by: Torstein Grindvik <[email protected]>
  • Loading branch information
Torstein Grindvik committed Nov 22, 2024
1 parent 8ae6788 commit 2e9b711
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 10 deletions.
47 changes: 40 additions & 7 deletions crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,37 @@ use cubecl_runtime::ExecutionMode;

/// A kernel, compiled in the target language
pub struct CompiledKernel<C: Compiler> {
pub kernel_name: String,
/// The name of the kernel entrypoint.
/// For example
///
/// ```no_run
/// #[cube(launch)]
/// fn gelu_array<F: Float, R: Runtime>
/// ```
///
/// would have the entrypoint name "gelu_array".
pub entrypoint_name: String,

/// A fully qualified debug name of the kernel.
///
/// For example
///
/// ```no_run
/// #[cube(launch)]
/// fn gelu_array<F: Float, R: Runtime>
/// ```
///
/// would have a debug name such as
///
/// ```no_run
/// gelu::gelu_array::GeluArray<
/// cubecl_core::frontend::element::float::F32,
/// cubecl_cuda::runtime::CudaRuntime,
/// >
/// ```
pub debug_name: Option<&'static str>,

/// Source code of the kernel
pub source: String,
/// In-memory representation of the kernel
Expand Down Expand Up @@ -48,11 +78,13 @@ impl<C: Compiler> Display for CompiledKernel<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("\n[START_KERNEL_COMPILATION]")?;

if self.kernel_name.len() <= 32 {
f.write_fmt(format_args!("\nname: {}", self.kernel_name))?;
} else {
let name = format_str(&self.kernel_name, &[('<', '>')], false);
f.write_fmt(format_args!("\nname: {name}"))?;
if let Some(name) = self.debug_name {
if name.len() <= 32 {
f.write_fmt(format_args!("\nname: {name}"))?;
} else {
let name = format_str(name, &[('<', '>')], false);
f.write_fmt(format_args!("\nname: {name}"))?;
}
}

f.write_fmt(format_args!(
Expand Down Expand Up @@ -192,12 +224,13 @@ impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
let shared_mem_bytes = lower_level_ir.shared_memory_size();

CompiledKernel {
entrypoint_name: kernel_name,
debug_name: Some(core::any::type_name::<K>()),
source: lower_level_ir.to_string(),
repr: Some(lower_level_ir),
cube_dim,
shared_mem_bytes,
debug_info: None,
kernel_name,
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ impl CudaContext {
cudarc::nvrtc::result::get_ptx(program).unwrap()
};

let func_name = CString::new(kernel_compiled.kernel_name).unwrap();
let func_name = CString::new(kernel_compiled.entrypoint_name).unwrap();
let func = unsafe {
let module =
cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl WgpuCompiler for SpirvCompiler<GLCompute> {
label: None,
layout: layout.as_ref(),
module: &module,
entry_point: &kernel.kernel_name,
entry_point: &kernel.entrypoint_name,
compilation_options: wgpu::PipelineCompilationOptions {
zero_initialize_workgroup_memory: false,
..Default::default()
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl WgpuCompiler for WgslCompiler {
label: None,
layout: layout.as_ref(),
module: &module,
entry_point: &kernel.kernel_name,
entry_point: &kernel.entrypoint_name,
compilation_options: wgpu::PipelineCompilationOptions {
zero_initialize_workgroup_memory: false,
..Default::default()
Expand Down

0 comments on commit 2e9b711

Please sign in to comment.