From 0922d15c80e32f435a1a8df1c4602d5e3865db15 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 1 Dec 2023 19:30:02 -0800 Subject: [PATCH] [xla:gpu] Add kernel launch context to arguments packing to be able to query kernel occupancy data Query CUTLASS kernel occupancy for custom GEMMs PiperOrigin-RevId: 587199835 --- .../gpu/kernels/cutlass_gemm_kernel.cu.h | 13 +++++--- .../xla/xla/stream_executor/cuda/BUILD | 3 ++ .../cuda/cuda_command_buffer_test.cc | 3 +- .../xla/stream_executor/cuda/cuda_executor.cc | 8 ++++- .../xla/stream_executor/cuda/cuda_kernel.cc | 25 +++++++++++++-- third_party/xla/xla/stream_executor/gpu/BUILD | 2 ++ .../xla/xla/stream_executor/gpu/gpu_kernel.h | 32 +++++++++++++------ third_party/xla/xla/stream_executor/kernel.cc | 20 ++++++++++++ third_party/xla/xla/stream_executor/kernel.h | 28 +++++++++++++++- .../xla/xla/stream_executor/kernel_spec.h | 3 +- .../stream_executor_internal.h | 7 ++++ 11 files changed, 124 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h index 0979e656c0387e..78e0ed36c53b77 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h @@ -152,7 +152,8 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, using PackedArgs = StatusOr>; - return [=](const se::KernelArgs &args) -> PackedArgs { + return [=](const se::KernelLaunchContext &ctx, + const se::KernelArgs &args) -> PackedArgs { auto *mem_args = Cast(&args); cutlass::Status can_implement = Kernel::can_implement(problem_size); @@ -189,11 +190,15 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size, lda, ldb, ldc, ldc // strides ); - // TODO(ezhulenev): Get number of SMs from a DeviceDescription and calculate - // correct kernel occupancy using GpuRuntime. + // Query kernel API for SM occupancy for the launch dimensions. + TF_ASSIGN_OR_RETURN(int32_t sm_occupancy, + ctx.kernel()->GetMaxOccupiedBlocksPerCore( + ctx.threads(), args.number_of_shared_bytes())); + + // TODO(ezhulenev): Get number of SMs from DeviceDescription. // Convert CUTLASS operation arguments to a device kernel parameters. - Params params(arguments, /*device_sms=*/128, /*sm_occupancy=*/10); + Params params(arguments, /*device_sms=*/128, sm_occupancy); // Optionally set up dynamic slice parameters to allow kernel adjust buffer // pointers passed via `params`. diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index d292bc57ace60b..a75d742f8c4b95 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -441,10 +441,13 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":cuda_driver", + "@com_google_absl//absl/log", "@local_config_cuda//cuda:cuda_headers", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/gpu:gpu_kernel_header", + "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/platform", + "@local_tsl//tsl/platform:statusor", ]), ) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc index d3064b4bdde443..2c40db007be26b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer_test.cc @@ -114,7 +114,8 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { // Register a kernel with a custom arguments packing function that packs // device memory arguments into a struct with pointers. - MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelArgs& args) { + MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelLaunchContext&, + const KernelArgs& args) { auto bufs = Cast(&args)->device_memory_args(); auto cast = [](auto m) { return reinterpret_cast(m.opaque()); }; return PackKernelArgs(add, internal::Ptrs3{ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 500ac1789e495d..419590c30a08be 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -239,6 +240,10 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, cuda_kernel->gpu_function_ptr())); } + // Update CUDA kernel properties after it was loaded in the CUDA context. + cuda_kernel->set_name(*kernel_name); + cuda_kernel->set_gpu_context(context_); + // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the CUDA API. cuda_kernel->set_arity(spec.arity()); @@ -464,7 +469,8 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, "Kernel is missing a custom arguments packing function for device " "memory arguments array"); - TF_ASSIGN_OR_RETURN(auto packed, pack(*device_mem)); + KernelLaunchContext ctx(&kernel, block_dims, thread_dims); + TF_ASSIGN_OR_RETURN(auto packed, pack(ctx, *device_mem)); return launch(*packed); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc index 2840c0f8165e8f..464f2118128a4b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc @@ -13,7 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_kernel.h" +#include +#include + +#include "absl/log/log.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -30,9 +39,21 @@ CUfunc_cache GpuKernel::GetGpuCacheConfig() const { return CU_FUNC_CACHE_PREFER_EQUAL; default: LOG(FATAL) << "Unknown KernelCacheConfig" - << static_cast(preferred_cache_config_); + << static_cast(preferred_cache_config_); } } +tsl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + int32_t threads_per_block = threads.x * threads.y * threads.z; + VLOG(0) << "Get kernel block occupancy: " << name_ + << "; threads_per_block: " << threads_per_block + << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; + + return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, + threads_per_block, + dynamic_shared_memory_bytes); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index ba8f01430baeb9..302c69579fdf10 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -223,10 +223,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gpu_driver_header", + ":gpu_types_header", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor:stream_executor_internal", "//xla/stream_executor/platform", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index 09443a23259b58..7f8ea596902133 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -22,11 +22,17 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ +#include +#include +#include +#include + #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor_internal.h" -#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { @@ -35,10 +41,7 @@ namespace gpu { // KernelInterface. class GpuKernel : public internal::KernelInterface { public: - GpuKernel() - : gpu_function_(nullptr), - arity_(0), - preferred_cache_config_(KernelCacheConfig::kNoPreference) {} + GpuKernel() = default; // Note that the function is unloaded when the module is unloaded, and the // module that the function is contained in is owned by the GpuExecutor. @@ -49,6 +52,9 @@ class GpuKernel : public internal::KernelInterface { void set_arity(unsigned arity) { arity_ = arity; } unsigned Arity() const override { return arity_; } + void set_name(std::string name) { name_ = std::move(name); } + void set_gpu_context(GpuContext* gpu_context) { gpu_context_ = gpu_context; } + // Returns the GpuFunctionHandle value for passing to the CUDA API. GpuFunctionHandle AsGpuFunctionHandle() const { DCHECK(gpu_function_ != nullptr); @@ -79,12 +85,18 @@ class GpuKernel : public internal::KernelInterface { // CUfunc_cache. GpuFuncCachePreference GetGpuCacheConfig() const; + tsl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + private: - GpuFunctionHandle gpu_function_; // Wrapped CUDA kernel handle. - unsigned arity_; // Number of formal parameters the kernel takes. + GpuContext* gpu_context_ = nullptr; // context where kernel is loaded + std::string name_; // kernel name + + GpuFunctionHandle gpu_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes - // Preferred (but not required) cache configuration for this kernel. - KernelCacheConfig preferred_cache_config_; + // Preferred (but not required) cache configuration for this kernel + KernelCacheConfig preferred_cache_config_ = KernelCacheConfig::kNoPreference; }; // Given a platform-independent kernel datatype, returns the (const) internal diff --git a/third_party/xla/xla/stream_executor/kernel.cc b/third_party/xla/xla/stream_executor/kernel.cc index 3b06a32fc31c5c..4bf8f3987df9b6 100644 --- a/third_party/xla/xla/stream_executor/kernel.cc +++ b/third_party/xla/xla/stream_executor/kernel.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/kernel.h" +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_internal.h" #include "tsl/platform/demangle.h" +#include "tsl/platform/statusor.h" namespace stream_executor { @@ -45,6 +47,18 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) { shared_memory_bytes_ = shared_memory_bytes; } +//===----------------------------------------------------------------------===// +// KernelLaunchContext +//===----------------------------------------------------------------------===// + +KernelLaunchContext::KernelLaunchContext(const Kernel *kernel, BlockDim blocks, + ThreadDim threads) + : kernel_(kernel), blocks_(blocks), threads_(threads) {} + +//===----------------------------------------------------------------------===// +// Kernel +//===----------------------------------------------------------------------===// + Kernel::Kernel(Kernel &&from) : parent_(from.parent_), implementation_(std::move(from.implementation_)), @@ -74,6 +88,12 @@ KernelCacheConfig Kernel::GetPreferredCacheConfig() const { return implementation_->GetPreferredCacheConfig(); } +tsl::StatusOr Kernel::GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + return implementation_->GetMaxOccupiedBlocksPerCore( + threads, dynamic_shared_memory_bytes); +} + void Kernel::set_name(absl::string_view name) { name_ = std::string(name); diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index 9077f80955ed8e..1a3ad3b4775cb3 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -88,10 +88,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/launch_dim.h" #include "tsl/platform/statusor.h" namespace stream_executor { +class Kernel; class StreamExecutor; namespace internal { @@ -209,6 +211,25 @@ class KernelArgsPackedArrayBase : public KernelArgs { Kind kind() const final { return Kind::kPackedArray; } }; +//===----------------------------------------------------------------------===// +// KernelLaunchContext +//===----------------------------------------------------------------------===// + +// Properties of a kernel launch that might impact kernel arguments packing. +class KernelLaunchContext { + public: + KernelLaunchContext(const Kernel *kernel, BlockDim blocks, ThreadDim threads); + + const Kernel *kernel() const { return kernel_; } + BlockDim blocks() const { return blocks_; } + ThreadDim threads() const { return threads_; } + + private: + const Kernel *kernel_; + BlockDim blocks_; + ThreadDim threads_; +}; + //===----------------------------------------------------------------------===// // Kernel //===----------------------------------------------------------------------===// @@ -226,7 +247,7 @@ class Kernel { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelArgs &args)>; + const KernelLaunchContext &ctx, const KernelArgs &args)>; Kernel(Kernel &&from); @@ -268,6 +289,11 @@ class Kernel { // Gets the preferred cache configuration for a kernel. KernelCacheConfig GetPreferredCacheConfig() const; + // Returns the maximum number of blocks (per multiprocessor) occupied by the + // kernel given the number of threads per block and shared memory size. + tsl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const; + // Sets custom kernels arguments packing function for a kernel. void set_kernel_args_packing(KernelArgsPacking kernel_args_packing) { kernel_args_packing_ = std::move(kernel_args_packing); diff --git a/third_party/xla/xla/stream_executor/kernel_spec.h b/third_party/xla/xla/stream_executor/kernel_spec.h index 6144944306bef2..e49eae4ec322b4 100644 --- a/third_party/xla/xla/stream_executor/kernel_spec.h +++ b/third_party/xla/xla/stream_executor/kernel_spec.h @@ -62,6 +62,7 @@ limitations under the License. namespace stream_executor { class KernelArgs; // defined in kernel.h +class KernelLaunchContext; // defined in kernel.h class KernelArgsPackedArrayBase; // defined in kernel.h // Describes how to load a kernel on a target platform. @@ -262,7 +263,7 @@ class MultiKernelLoaderSpec { // StreamExecutor as a generic `Kernel`. using KernelArgsPacking = std::function>( - const KernelArgs &args)>; + const KernelLaunchContext &ctx, const KernelArgs &args)>; explicit MultiKernelLoaderSpec( size_t arity, KernelArgsPacking kernel_args_packing = nullptr); diff --git a/third_party/xla/xla/stream_executor/stream_executor_internal.h b/third_party/xla/xla/stream_executor/stream_executor_internal.h index 8c17dbdc155603..9931e5b0984afd 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_internal.h +++ b/third_party/xla/xla/stream_executor/stream_executor_internal.h @@ -99,6 +99,13 @@ class KernelInterface { // Gets the preferred cache configuration. virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; + // Returns the maximum number of blocks (per multiprocessor) occupied by the + // kernel given the number of threads per block and shared memory size. + virtual tsl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const { + return absl::UnimplementedError("Not Implemented"); + } + private: KernelInterface(const KernelInterface&) = delete; void operator=(const KernelInterface&) = delete;