From 9c9ad4957a70bd5e048d7124fdd000402b5d83bd Mon Sep 17 00:00:00 2001
From: nathaniel
Date: Fri, 26 Jul 2024 12:09:11 -0400
Subject: [PATCH] Improve output

---
Notes (text in this section is not part of the commit message):
 * CompiledKernel gains an optional `lang_tag`; its Display impl prints the
   tag right after the opening code fence (empty string when it is None).
   KernelTask::compile leaves the tag as None.
 * The CUDA server sets the tag to "cpp" and, when format_cpp_code returns
   Ok, replaces the kernel source with the clang-format output; when it
   returns Err the source is left untouched.
 * The wgpu server sets the tag to "wgsl" before passing the kernel to the
   logger.

 crates/cubecl-core/src/compute/kernel.rs  | 11 ++++++--
 crates/cubecl-cuda/src/compiler/kernel.rs | 32 ++++++++++++++++++-----
 crates/cubecl-cuda/src/compute/server.rs  | 10 ++++++-
 crates/cubecl-wgpu/src/compute/server.rs  |  4 ++-
 4 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs
index a91657fc8..065492893 100644
--- a/crates/cubecl-core/src/compute/kernel.rs
+++ b/crates/cubecl-core/src/compute/kernel.rs
@@ -16,6 +16,7 @@ pub struct CompiledKernel {
     pub cube_dim: CubeDim,
     /// The number of bytes used by the share memory
     pub shared_mem_bytes: usize,
+    pub lang_tag: Option<&'static str>,
 }
 
 impl Display for CompiledKernel {
@@ -36,12 +37,17 @@ impl Display for CompiledKernel {
 cube_dim: ({}, {}, {})
 shared_memory: {} bytes
 source:
-```
+```{}
 {}
 ```
 =================================
 ",
-            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z, self.shared_mem_bytes, self.source
+            self.cube_dim.x,
+            self.cube_dim.y,
+            self.cube_dim.z,
+            self.shared_mem_bytes,
+            self.lang_tag.unwrap_or(""),
+            self.source
         ))
     }
 }
@@ -107,6 +113,7 @@ impl CubeTask for KernelTask {
             source,
             cube_dim,
             shared_mem_bytes,
+            lang_tag: None,
         }
     }
 
diff --git a/crates/cubecl-cuda/src/compiler/kernel.rs b/crates/cubecl-cuda/src/compiler/kernel.rs
index c239d154f..a32f8d19d 100644
--- a/crates/cubecl-cuda/src/compiler/kernel.rs
+++ b/crates/cubecl-cuda/src/compiler/kernel.rs
@@ -1,6 +1,6 @@
 use super::{Body, Item};
 use cubecl_core::{ir::CubeDim, CompilerRepresentation};
-use std::{collections::HashSet, fmt::Display};
+use std::{collections::HashSet, fmt::Display, io::Write, process::Command};
 
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub struct Binding {
@@ -84,11 +84,7 @@ impl Display for ComputeKernel {
             f.write_str("using namespace nvcuda;\n")?;
         }
 
-        f.write_str(
-            "
-typedef unsigned int uint;
-            ",
-        )?;
+        f.write_str("typedef unsigned int uint;\n")?;
 
         for item in self.items.iter() {
             if item.is_vec_native() {
@@ -155,3 +151,27 @@ extern \"C\" __global__ void kernel(
         Ok(())
     }
 }
+
+/// Format C++ code, useful when debugging.
+pub(crate) fn format_cpp_code(code: &str) -> Result<String, std::io::Error> {
+    let mut child = Command::new("clang-format")
+        .stdin(std::process::Stdio::piped())
+        .stdout(std::process::Stdio::piped())
+        .spawn()?;
+
+    {
+        let stdin = child.stdin.as_mut().expect("Failed to open stdin");
+        stdin.write_all(code.as_bytes())?;
+    }
+
+    let output = child.wait_with_output()?;
+
+    if output.status.success() {
+        Ok(String::from_utf8_lossy(&output.stdout).into_owned())
+    } else {
+        Err(std::io::Error::new(
+            std::io::ErrorKind::Other,
+            "clang-format failed",
+        ))
+    }
+}
diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs
index b0aa76bce..cb7d7ec32 100644
--- a/crates/cubecl-cuda/src/compute/server.rs
+++ b/crates/cubecl-cuda/src/compute/server.rs
@@ -1,3 +1,5 @@
+use crate::compiler::format_cpp_code;
+
 use super::storage::CudaStorage;
 use super::CudaResource;
 use cubecl_common::reader::{reader_from_concrete, Reader};
@@ -196,7 +198,13 @@ impl> CudaContext {
         arch: i32,
         logger: &mut DebugLogger,
     ) {
-        let kernel_compiled = kernel.compile();
+        let mut kernel_compiled = kernel.compile();
+        kernel_compiled.lang_tag = Some("cpp");
+
+        if let Ok(formatted) = format_cpp_code(&kernel_compiled.source) {
+            kernel_compiled.source = formatted;
+        }
+
         let shared_mem_bytes = kernel_compiled.shared_mem_bytes;
         let cube_dim = kernel_compiled.cube_dim;
         let arch = format!("--gpu-architecture=sm_{}", arch);
diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs
index 9d1471239..0db048701 100644
--- a/crates/cubecl-wgpu/src/compute/server.rs
+++ b/crates/cubecl-wgpu/src/compute/server.rs
@@ -105,7 +105,9 @@ where
             return pipeline.clone();
         }
 
-        let compile = kernel.compile();
+        let mut compile = kernel.compile();
+        compile.lang_tag = Some("wgsl");
+
         let compile = self.logger.debug(compile);
         let pipeline = self.compile_source(&compile.source);
 