Propagate kernel name via KernelSettings (#278)
Signed-off-by: Torstein Grindvik <[email protected]>
Co-authored-by: Torstein Grindvik <[email protected]>
torsteingrindvik and Torstein Grindvik authored Nov 22, 2024
1 parent 9f39294 commit a0b1971
Showing 19 changed files with 91 additions and 19 deletions.
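At a high level, this commit threads a kernel name from `KernelSettings` through `KernelExpansion` and `KernelDefinition` into each backend, which previously hard-coded the entry point as `kernel` (CUDA/C++) or `main` (WGSL/SPIR-V). The sketch below mirrors that flow with simplified stand-in types; it is not the actual cubecl code, just a minimal model of the propagation.

```rust
// Standalone sketch of the name propagation introduced in this commit.
// `KernelSettings`, `KernelDefinition`, `integrate` and `emit_wgsl_entry`
// are simplified stand-ins, not the real cubecl definitions.

#[derive(Default)]
struct KernelSettings {
    kernel_name: String,
}

impl KernelSettings {
    /// Builder-style setter, mirroring the new `KernelSettings::kernel_name`.
    fn kernel_name(mut self, name: &'static str) -> Self {
        self.kernel_name = name.to_string();
        self
    }
}

struct KernelDefinition {
    kernel_name: String,
}

/// Stand-in for the integration step: the name recorded in the settings
/// ends up on the kernel definition handed to the compiler backends.
fn integrate(settings: KernelSettings) -> KernelDefinition {
    KernelDefinition {
        kernel_name: settings.kernel_name,
    }
}

/// Stand-in for a backend: emit the entry point with the propagated name
/// instead of a hard-coded `fn main(`.
fn emit_wgsl_entry(def: &KernelDefinition) -> String {
    format!("fn {}(", def.kernel_name)
}

fn main() {
    // The macro-generated constructor now seeds the name from the
    // `#[cube(launch)]` function's identifier, e.g. "gelu_array".
    let settings = KernelSettings::default().kernel_name("gelu_array");
    let def = integrate(settings);
    assert_eq!(emit_wgsl_entry(&def), "fn gelu_array(");
}
```

The per-file changes below implement this wiring in the real types.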
10 changes: 10 additions & 0 deletions crates/cubecl-core/src/codegen/integrator.rs
@@ -26,6 +26,7 @@ pub struct KernelExpansion {
pub inputs: Vec<InputInfo>,
pub outputs: Vec<OutputInfo>,
pub scope: Scope,
pub kernel_name: String,
}

/// Simply indicate the output that can be replaced by the input.
@@ -55,6 +56,7 @@ pub struct KernelSettings {
vectorization_partial: Vec<VectorizationPartial>,
pub cube_dim: CubeDim,
pub reading_strategy: Vec<(u16, ReadingStrategy)>,
pub kernel_name: String,
}

impl core::fmt::Display for KernelSettings {
@@ -193,6 +195,13 @@ impl KernelSettings {
self.cube_dim = cube_dim;
self
}

/// Set kernel name.
#[allow(dead_code)]
pub fn kernel_name(mut self, name: &'static str) -> Self {
self.kernel_name = name.to_string();
self
}
}

#[allow(dead_code)]
@@ -331,6 +340,7 @@ impl KernelIntegrator {
named,
cube_dim: settings.cube_dim,
body: self.expansion.scope,
kernel_name: self.expansion.kernel_name,
}
}

1 change: 1 addition & 0 deletions crates/cubecl-core/src/compute/builder.rs
@@ -112,6 +112,7 @@ impl KernelBuilder {
scope: self.context.into_scope(),
inputs: self.inputs,
outputs: self.outputs,
kernel_name: settings.kernel_name.clone(),
})
.integrate(settings)
}
38 changes: 35 additions & 3 deletions crates/cubecl-core/src/compute/kernel.rs
@@ -6,7 +6,37 @@ use cubecl_runtime::ExecutionMode;

/// A kernel, compiled in the target language
pub struct CompiledKernel<C: Compiler> {
pub name: Option<&'static str>,
/// The name of the kernel entrypoint.
/// For example
///
/// ```text
/// #[cube(launch)]
/// fn gelu_array<F: Float, R: Runtime>() {}
/// ```
///
/// would have the entrypoint name "gelu_array".
pub entrypoint_name: String,

/// A fully qualified debug name of the kernel.
///
/// For example
///
/// ```text
/// #[cube(launch)]
/// fn gelu_array<F: Float, R: Runtime>() {}
/// ```
///
/// would have a debug name such as
///
/// ```text
/// gelu::gelu_array::GeluArray<
/// cubecl_core::frontend::element::float::F32,
/// cubecl_cuda::runtime::CudaRuntime,
/// >
/// ```
pub debug_name: Option<&'static str>,

/// Source code of the kernel
pub source: String,
/// In-memory representation of the kernel
@@ -48,7 +78,7 @@ impl<C: Compiler> Display for CompiledKernel<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("\n[START_KERNEL_COMPILATION]")?;

if let Some(name) = self.name {
if let Some(name) = self.debug_name {
if name.len() <= 32 {
f.write_fmt(format_args!("\nname: {name}"))?;
} else {
@@ -188,12 +218,14 @@ pub struct KernelTask<C: Compiler, K: Kernel> {
impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
let gpu_ir = self.kernel_definition.define();
let kernel_name = gpu_ir.kernel_name.clone();
let cube_dim = gpu_ir.cube_dim;
let lower_level_ir = C::compile(gpu_ir, mode);
let shared_mem_bytes = lower_level_ir.shared_memory_size();

CompiledKernel {
name: Some(core::any::type_name::<K>()),
entrypoint_name: kernel_name,
debug_name: Some(core::any::type_name::<K>()),
source: lower_level_ir.to_string(),
repr: Some(lower_level_ir),
cube_dim,
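For context on the `CompiledKernel` split above: `debug_name` still carries the fully qualified Rust type (via `core::any::type_name`), while the new `entrypoint_name` is the short name that backends emit and later look up. A minimal sketch of the difference, using a hypothetical `GeluArray` type standing in for a macro-generated kernel struct:

```rust
// Minimal sketch of how the two names differ. `GeluArray` is a hypothetical
// stand-in for a generated kernel struct, not a type from this commit.
use std::marker::PhantomData;

#[allow(dead_code)]
struct GeluArray<F, R>(PhantomData<(F, R)>);

fn main() {
    // `debug_name` is derived from the full Rust type name of the kernel...
    let debug_name = std::any::type_name::<GeluArray<f32, u32>>();
    // ...while `entrypoint_name` is the short `#[cube(launch)]` function name
    // used for the entry point in the generated source.
    let entrypoint_name = "gelu_array";

    println!("debug name:      {debug_name}");
    println!("entrypoint name: {entrypoint_name}");
}
```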
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/kernel.rs
@@ -12,6 +12,7 @@ pub struct KernelDefinition {
pub named: Vec<(String, Binding)>,
pub cube_dim: CubeDim,
pub body: Scope,
pub kernel_name: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
1 change: 1 addition & 0 deletions crates/cubecl-cpp/src/shared/base.rs
@@ -124,6 +124,7 @@ impl<D: Dialect> CppCompiler<D> {
bf16: self.bf16,
f16: self.f16,
items: self.items,
kernel_name: value.kernel_name,
}
}

4 changes: 3 additions & 1 deletion crates/cubecl-cpp/src/shared/kernel.rs
@@ -59,6 +59,7 @@ pub struct ComputeKernel<D: Dialect> {
pub bf16: bool,
pub f16: bool,
pub items: HashSet<super::Item<D>>,
pub kernel_name: String,
}

impl<D: Dialect> CompilerRepresentation for ComputeKernel<D> {
@@ -120,8 +121,9 @@ struct __align__({alignment}) {item} {{"
f,
"
extern \"C\" __global__ void kernel(
extern \"C\" __global__ void {}(
",
self.kernel_name
)?;

let num_bindings = self.inputs.len() + self.outputs.len() + self.named.len();
2 changes: 1 addition & 1 deletion crates/cubecl-cuda/src/compute/server.rs
@@ -378,7 +378,7 @@ impl CudaContext {
cudarc::nvrtc::result::get_ptx(program).unwrap()
};

let func_name = CString::new("kernel".to_string()).unwrap();
let func_name = CString::new(kernel_compiled.entrypoint_name).unwrap();
let func = unsafe {
let module =
cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _).unwrap();
3 changes: 2 additions & 1 deletion crates/cubecl-macros/src/generate/kernel.rs
@@ -236,6 +236,7 @@ impl Launch {
.map(|_| quote![__ty: ::core::marker::PhantomData]);
let (compilation_args, args) = self.compilation_args_def();
let info = param_names.clone().into_iter().chain(args.clone());
let sig_name = self.func.sig.name.to_string();

quote! {
#[doc = #kernel_doc]
@@ -250,7 +251,7 @@
impl #generics #kernel_name #generic_names #where_clause {
pub fn new(settings: #kernel_settings, #(#compilation_args,)* #(#const_params),*) -> Self {
Self {
settings,
settings: settings.kernel_name(#sig_name),
#(#args,)*
#(#param_names,)*
#phantom_data_init
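The macro change above is where the name originates: `self.func.sig.name` is stringified and passed to the new `kernel_name` setter inside the generated constructor. A rough sketch of that slice of the expansion, written outside the real `Launch` machinery (illustrative names; requires the `quote` and `proc-macro2` crates):

```rust
// Hedged sketch of the relevant part of the macro expansion: the launched
// function's identifier is interpolated into the generated constructor as
// `settings.kernel_name("...")`. Names are illustrative, not from the diff.
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

fn generate_new(sig_name: &str) -> TokenStream {
    let kernel_name = format_ident!("GeluArray");
    quote! {
        impl #kernel_name {
            pub fn new(settings: KernelSettings) -> Self {
                Self {
                    // `#sig_name` expands to a string literal, e.g. "gelu_array".
                    settings: settings.kernel_name(#sig_name),
                }
            }
        }
    }
}

fn main() {
    // Prints tokens containing `settings . kernel_name ("gelu_array")`.
    println!("{}", generate_new("gelu_array"));
}
```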
2 changes: 2 additions & 0 deletions crates/cubecl-spirv/src/compiler.rs
@@ -166,6 +166,8 @@ impl<Target: SpirvTarget> SpirvCompiler<Target> {
let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];

let mut target = self.target.clone();
target.set_kernel_name(kernel.kernel_name);

let extensions = target.extensions(self);
self.state.extensions = extensions;

26 changes: 23 additions & 3 deletions crates/cubecl-spirv/src/target.rs
@@ -24,10 +24,21 @@ pub trait SpirvTarget:
name: String,
index: u32,
) -> Word;
fn set_kernel_name(&mut self, name: impl Into<String>);
}

#[derive(Clone, Default)]
pub struct GLCompute;
#[derive(Clone)]
pub struct GLCompute {
kernel_name: String,
}

impl Default for GLCompute {
fn default() -> Self {
Self {
kernel_name: "main".into(),
}
}
}

impl Debug for GLCompute {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -66,7 +77,12 @@ impl SpirvTarget for GLCompute {
}

b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
b.entry_point(ExecutionModel::GLCompute, main, "main", interface);
b.entry_point(
ExecutionModel::GLCompute,
main,
&self.kernel_name,
interface,
);
b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
}

@@ -110,4 +126,8 @@ impl SpirvTarget for GLCompute {
fn extensions(&mut self, b: &mut SpirvCompiler<Self>) -> Vec<Word> {
vec![b.ext_inst_import("GLSL.std.450")]
}

fn set_kernel_name(&mut self, name: impl Into<String>) {
self.kernel_name = name.into();
}
}
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/spirv.rs
@@ -130,7 +130,7 @@ impl WgpuCompiler for SpirvCompiler<GLCompute> {
label: None,
layout: layout.as_ref(),
module: &module,
entry_point: "main",
entry_point: &kernel.entrypoint_name,
compilation_options: wgpu::PipelineCompilationOptions {
zero_initialize_workgroup_memory: false,
..Default::default()
3 changes: 2 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
@@ -139,7 +139,7 @@ impl WgpuCompiler for WgslCompiler {
label: None,
layout: layout.as_ref(),
module: &module,
entry_point: "main",
entry_point: &kernel.entrypoint_name,
compilation_options: wgpu::PipelineCompilationOptions {
zero_initialize_workgroup_memory: false,
..Default::default()
@@ -269,6 +269,7 @@ impl WgslCompiler {
num_workgroups_no_axis: self.num_workgroup_no_axis,
workgroup_id_no_axis: self.workgroup_id_no_axis,
workgroup_size_no_axis: self.workgroup_size_no_axis,
kernel_name: value.kernel_name,
}
}

5 changes: 3 additions & 2 deletions crates/cubecl-wgpu/src/compiler/wgsl/shader.rs
@@ -88,6 +88,7 @@ pub struct ComputeShader {
pub workgroup_size_no_axis: bool,
pub body: Body,
pub extensions: Vec<Extension>,
pub kernel_name: String,
}

impl Display for ComputeShader {
@@ -143,9 +144,9 @@ const WORKGROUP_SIZE_Z = {}u;\n",
"
@compute
@workgroup_size({}, {}, {})
fn main(
fn {}(
",
self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z
self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z, self.kernel_name
)?;

if self.global_invocation_id {
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/tests/constant_array.wgsl
@@ -14,7 +14,7 @@ const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(16, 16, 1)
fn main(
fn constant_array_kernel(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/tests/plane_elect.wgsl
@@ -12,7 +12,7 @@ const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(4, 1, 1)
fn main(
fn kernel_elect(
@builtin(local_invocation_index) local_idx: u32,
) {
let _0 = subgroupElect();
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/tests/plane_sum.wgsl
@@ -12,7 +12,7 @@ const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(4, 1, 1)
fn main(
fn kernel_sum(
@builtin(local_invocation_index) local_idx: u32,
) {
let _0 = output_0_global[local_idx];
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/tests/sequence_for_loop.wgsl
@@ -12,7 +12,7 @@ const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(16, 16, 1)
fn main(
fn sequence_for_loop_kernel(
@builtin(local_invocation_index) local_idx: u32,
) {
let _0 = local_idx != 0u;
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/tests/slice_assign.wgsl
@@ -16,7 +16,7 @@ const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(1, 1, 1)
fn main(
fn slice_assign_kernel(
@builtin(local_invocation_index) local_idx: u32,
) {
let _0 = local_idx == 0u;
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/tests/unary_bench.wgsl
@@ -20,7 +20,7 @@ const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(16, 16, 1)
fn main(
fn execute_unary_kernel(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
