Commit
[xla:gpu] Add kernel launch context to arguments packing to be able to query kernel occupancy data

Query CUTLASS kernel occupancy for custom GEMMs

PiperOrigin-RevId: 587199835
ezhulenev authored and tensorflower-gardener committed Dec 2, 2023
1 parent 3880d85 commit 0922d15
Showing 11 changed files with 124 additions and 20 deletions.
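
Editor's note: the commit threads a new KernelLaunchContext through the custom kernel-arguments packing callback, so a packer can ask the kernel for its per-SM occupancy at the requested launch dimensions instead of hard-coding it. A minimal sketch of the updated callback shape follows; the se:: types, GetMaxOccupiedBlocksPerCore, and number_of_shared_bytes() all appear in the diffs below, while BuildDeviceParams is a hypothetical placeholder.

#include <cstdint>
#include <memory>

#include "xla/stream_executor/kernel.h"
#include "tsl/platform/statusor.h"

namespace se = stream_executor;

// Sketch of the packing-callback shape after this commit (not a drop-in
// implementation).
using PackedArgs =
    tsl::StatusOr<std::unique_ptr<se::KernelArgsPackedArrayBase>>;

auto packing = [](const se::KernelLaunchContext &ctx,
                  const se::KernelArgs &args) -> PackedArgs {
  // New: the packer can see the kernel and its launch dimensions, so it can
  // ask how many blocks per SM this launch will sustain.
  TF_ASSIGN_OR_RETURN(int32_t sm_occupancy,
                      ctx.kernel()->GetMaxOccupiedBlocksPerCore(
                          ctx.threads(), args.number_of_shared_bytes()));
  return BuildDeviceParams(args, sm_occupancy);  // hypothetical helper
};

The CUTLASS GEMM packer below uses exactly this hook to replace its hard-coded /*sm_occupancy=*/10.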
13 changes: 9 additions & 4 deletions third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel.cu.h
@@ -152,7 +152,8 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size,

using PackedArgs = StatusOr<std::unique_ptr<se::KernelArgsPackedArrayBase>>;

-  return [=](const se::KernelArgs &args) -> PackedArgs {
+  return [=](const se::KernelLaunchContext &ctx,
+             const se::KernelArgs &args) -> PackedArgs {
auto *mem_args = Cast<se::KernelArgsDeviceMemoryArray>(&args);

cutlass::Status can_implement = Kernel::can_implement(problem_size);
@@ -189,11 +190,15 @@ KernelArgsPacking ArgsPacking(cutlass::gemm::GemmCoord problem_size,
lda, ldb, ldc, ldc // strides
);

-    // TODO(ezhulenev): Get number of SMs from a DeviceDescription and calculate
-    // correct kernel occupancy using GpuRuntime.
+    // Query kernel API for SM occupancy for the launch dimensions.
+    TF_ASSIGN_OR_RETURN(int32_t sm_occupancy,
+                        ctx.kernel()->GetMaxOccupiedBlocksPerCore(
+                            ctx.threads(), args.number_of_shared_bytes()));
+
+    // TODO(ezhulenev): Get number of SMs from DeviceDescription.

    // Convert CUTLASS operation arguments to a device kernel parameters.
-    Params params(arguments, /*device_sms=*/128, /*sm_occupancy=*/10);
+    Params params(arguments, /*device_sms=*/128, sm_occupancy);

// Optionally set up dynamic slice parameters to allow kernel adjust buffer
// pointers passed via `params`.
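
Editor's note, for context (not stated in the diff itself): persistent CUTLASS-style GEMM kernels are typically launched with a grid bounded by how many thread blocks the device can keep resident at once, which is where device_sms * sm_occupancy comes in. A rough illustration of that arithmetic; the function and tiled_blocks are hypothetical, not XLA code.

#include <algorithm>
#include <cstdint>

// Illustrative only: cap a persistent kernel's grid at the number of thread
// blocks that can be resident at once (SM count times blocks per SM).
int64_t PersistentGridSize(int64_t tiled_blocks, int32_t device_sms,
                           int32_t sm_occupancy) {
  int64_t max_resident = static_cast<int64_t>(device_sms) * sm_occupancy;
  return std::min(tiled_blocks, max_resident);
}

This is why the queried sm_occupancy replaces the hard-coded value of 10, while device_sms stays at 128 until the DeviceDescription TODO above is resolved.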
3 changes: 3 additions & 0 deletions third_party/xla/xla/stream_executor/cuda/BUILD
@@ -441,10 +441,13 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_is_configured([
":cuda_driver",
"@com_google_absl//absl/log",
"@local_config_cuda//cuda:cuda_headers",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor/gpu:gpu_kernel_header",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/platform",
"@local_tsl//tsl/platform:statusor",
]),
)

@@ -114,7 +114,8 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) {

// Register a kernel with a custom arguments packing function that packs
// device memory arguments into a struct with pointers.
-  MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelArgs& args) {
+  MultiKernelLoaderSpec spec(/*arity=*/1, [&](const KernelLaunchContext&,
+                                              const KernelArgs& args) {
auto bufs = Cast<KernelArgsDeviceMemoryArray>(&args)->device_memory_args();
auto cast = [](auto m) { return reinterpret_cast<int32_t*>(m.opaque()); };
return PackKernelArgs(add, internal::Ptrs3<int32_t>{
8 changes: 7 additions & 1 deletion third_party/xla/xla/stream_executor/cuda/cuda_executor.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
@@ -239,6 +240,10 @@ tsl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
cuda_kernel->gpu_function_ptr()));
}

// Update CUDA kernel properties after it was loaded in the CUDA context.
cuda_kernel->set_name(*kernel_name);
cuda_kernel->set_gpu_context(context_);

// We have to trust the kernel loader spec arity because there doesn't appear
// to be a way to reflect on the number of expected arguments w/the CUDA API.
cuda_kernel->set_arity(spec.arity());
@@ -464,7 +469,8 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
"Kernel is missing a custom arguments packing function for device "
"memory arguments array");

-  TF_ASSIGN_OR_RETURN(auto packed, pack(*device_mem));
+  KernelLaunchContext ctx(&kernel, block_dims, thread_dims);
+  TF_ASSIGN_OR_RETURN(auto packed, pack(ctx, *device_mem));
return launch(*packed);
}

25 changes: 23 additions & 2 deletions third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc
@@ -13,7 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/cuda_kernel.h"
#include <cstddef>
#include <cstdint>

#include "absl/log/log.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {
namespace gpu {
@@ -30,9 +39,21 @@ CUfunc_cache GpuKernel::GetGpuCacheConfig() const {
return CU_FUNC_CACHE_PREFER_EQUAL;
default:
LOG(FATAL) << "Unknown KernelCacheConfig"
-               << static_cast<int32>(preferred_cache_config_);
+               << static_cast<int32_t>(preferred_cache_config_);
}
}

tsl::StatusOr<int32_t> GpuKernel::GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const {
int32_t threads_per_block = threads.x * threads.y * threads.z;
VLOG(0) << "Get kernel block occupancy: " << name_
<< "; threads_per_block: " << threads_per_block
<< "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes;

return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_,
threads_per_block,
dynamic_shared_memory_bytes);
}

} // namespace gpu
} // namespace stream_executor
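
Editor's note: GpuDriver::GetMaxOccupiedBlocksPerCore itself is not part of this diff. On CUDA it presumably maps onto the driver occupancy API; below is a minimal sketch of such a wrapper, assuming only the public CUDA driver API. The free function is illustrative, and the real implementation would presumably also activate the kernel's GpuContext, which would explain why set_gpu_context is now recorded in GetKernel above.

#include <cstddef>
#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "third_party/gpus/cuda/include/cuda.h"

// Sketch: how many blocks of `threads_per_block` threads (with the given
// dynamic shared memory) can be resident per SM for `func`.
absl::StatusOr<int32_t> MaxOccupiedBlocksPerCore(
    CUfunction func, int threads_per_block,
    size_t dynamic_shared_memory_bytes) {
  int num_blocks = 0;
  CUresult result = cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blocks, func, threads_per_block, dynamic_shared_memory_bytes);
  if (result != CUDA_SUCCESS) {
    return absl::InternalError(
        absl::StrCat("cuOccupancyMaxActiveBlocksPerMultiprocessor error: ",
                     static_cast<int>(result)));
  }
  return num_blocks;
}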
2 changes: 2 additions & 0 deletions third_party/xla/xla/stream_executor/gpu/BUILD
@@ -223,10 +223,12 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":gpu_driver_header",
":gpu_types_header",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor/platform",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:statusor",
],
)

32 changes: 22 additions & 10 deletions third_party/xla/xla/stream_executor/gpu/gpu_kernel.h
@@ -22,11 +22,17 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_
#define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/platform/port.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {
namespace gpu {
@@ -35,10 +41,7 @@ namespace gpu {
// KernelInterface.
class GpuKernel : public internal::KernelInterface {
public:
-  GpuKernel()
-      : gpu_function_(nullptr),
-        arity_(0),
-        preferred_cache_config_(KernelCacheConfig::kNoPreference) {}
+  GpuKernel() = default;

// Note that the function is unloaded when the module is unloaded, and the
// module that the function is contained in is owned by the GpuExecutor.
@@ -49,6 +52,9 @@ class GpuKernel : public internal::KernelInterface {
void set_arity(unsigned arity) { arity_ = arity; }
unsigned Arity() const override { return arity_; }

void set_name(std::string name) { name_ = std::move(name); }
void set_gpu_context(GpuContext* gpu_context) { gpu_context_ = gpu_context; }

// Returns the GpuFunctionHandle value for passing to the CUDA API.
GpuFunctionHandle AsGpuFunctionHandle() const {
DCHECK(gpu_function_ != nullptr);
@@ -79,12 +85,18 @@
// CUfunc_cache.
GpuFuncCachePreference GetGpuCacheConfig() const;

tsl::StatusOr<int32_t> GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const override;

private:
-  GpuFunctionHandle gpu_function_;  // Wrapped CUDA kernel handle.
-  unsigned arity_;  // Number of formal parameters the kernel takes.
+  GpuContext* gpu_context_ = nullptr;  // context where kernel is loaded
+  std::string name_;                   // kernel name
+
+  GpuFunctionHandle gpu_function_ = nullptr;  // wrapped CUDA kernel handle
+  unsigned arity_ = 0;  // number of formal parameters the kernel takes
+
-  // Preferred (but not required) cache configuration for this kernel.
-  KernelCacheConfig preferred_cache_config_;
+  // Preferred (but not required) cache configuration for this kernel
+  KernelCacheConfig preferred_cache_config_ = KernelCacheConfig::kNoPreference;
};

// Given a platform-independent kernel datatype, returns the (const) internal
20 changes: 20 additions & 0 deletions third_party/xla/xla/stream_executor/kernel.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include "xla/stream_executor/kernel.h"

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
@@ -26,6 +27,7 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/stream_executor_internal.h"
#include "tsl/platform/demangle.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

@@ -45,6 +47,18 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) {
shared_memory_bytes_ = shared_memory_bytes;
}

//===----------------------------------------------------------------------===//
// KernelLaunchContext
//===----------------------------------------------------------------------===//

KernelLaunchContext::KernelLaunchContext(const Kernel *kernel, BlockDim blocks,
ThreadDim threads)
: kernel_(kernel), blocks_(blocks), threads_(threads) {}

//===----------------------------------------------------------------------===//
// Kernel
//===----------------------------------------------------------------------===//

Kernel::Kernel(Kernel &&from)
: parent_(from.parent_),
implementation_(std::move(from.implementation_)),
@@ -74,6 +88,12 @@ KernelCacheConfig Kernel::GetPreferredCacheConfig() const {
return implementation_->GetPreferredCacheConfig();
}

tsl::StatusOr<int32_t> Kernel::GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const {
return implementation_->GetMaxOccupiedBlocksPerCore(
threads, dynamic_shared_memory_bytes);
}

void Kernel::set_name(absl::string_view name) {
name_ = std::string(name);

28 changes: 27 additions & 1 deletion third_party/xla/xla/stream_executor/kernel.h
@@ -88,10 +88,12 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/launch_dim.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {

class Kernel;
class StreamExecutor;

namespace internal {
@@ -209,6 +211,25 @@ class KernelArgsPackedArrayBase : public KernelArgs {
Kind kind() const final { return Kind::kPackedArray; }
};

//===----------------------------------------------------------------------===//
// KernelLaunchContext
//===----------------------------------------------------------------------===//

// Properties of a kernel launch that might impact kernel arguments packing.
class KernelLaunchContext {
public:
KernelLaunchContext(const Kernel *kernel, BlockDim blocks, ThreadDim threads);

const Kernel *kernel() const { return kernel_; }
BlockDim blocks() const { return blocks_; }
ThreadDim threads() const { return threads_; }

private:
const Kernel *kernel_;
BlockDim blocks_;
ThreadDim threads_;
};

//===----------------------------------------------------------------------===//
// Kernel
//===----------------------------------------------------------------------===//
@@ -226,7 +247,7 @@ class Kernel {
// StreamExecutor as a generic `Kernel`.
using KernelArgsPacking =
std::function<tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>>(
-          const KernelArgs &args)>;
+          const KernelLaunchContext &ctx, const KernelArgs &args)>;

Kernel(Kernel &&from);

@@ -268,6 +289,11 @@
// Gets the preferred cache configuration for a kernel.
KernelCacheConfig GetPreferredCacheConfig() const;

// Returns the maximum number of blocks (per multiprocessor) occupied by the
// kernel given the number of threads per block and shared memory size.
tsl::StatusOr<int32_t> GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const;

// Sets custom kernels arguments packing function for a kernel.
void set_kernel_args_packing(KernelArgsPacking kernel_args_packing) {
kernel_args_packing_ = std::move(kernel_args_packing);
3 changes: 2 additions & 1 deletion third_party/xla/xla/stream_executor/kernel_spec.h
@@ -62,6 +62,7 @@ limitations under the License.
namespace stream_executor {

class KernelArgs; // defined in kernel.h
class KernelLaunchContext; // defined in kernel.h
class KernelArgsPackedArrayBase; // defined in kernel.h

// Describes how to load a kernel on a target platform.
@@ -262,7 +263,7 @@ class MultiKernelLoaderSpec {
// StreamExecutor as a generic `Kernel`.
using KernelArgsPacking =
std::function<tsl::StatusOr<std::unique_ptr<KernelArgsPackedArrayBase>>(
-          const KernelArgs &args)>;
+          const KernelLaunchContext &ctx, const KernelArgs &args)>;

explicit MultiKernelLoaderSpec(
size_t arity, KernelArgsPacking kernel_args_packing = nullptr);
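
Editor's note: every call site that constructs a MultiKernelLoaderSpec with a custom packer has to pick up the extra parameter; the command-buffer test earlier on this page shows the real updated call. As a further hedged sketch of the registration shape under the new signature (assumes the same headers and se:: alias as the sketch at the top of this page; the arity, ParamsFor, and the focus on launch dimensions are illustrative, not XLA code):

// Sketch only: registering a spec whose packer depends on launch dimensions.
se::MultiKernelLoaderSpec spec(
    /*arity=*/2, [](const se::KernelLaunchContext &ctx,
                    const se::KernelArgs &args)
                     -> tsl::StatusOr<
                         std::unique_ptr<se::KernelArgsPackedArrayBase>> {
      // Grid and block dimensions were previously invisible to packers; now
      // they can shape the packed parameters directly.
      se::BlockDim blocks = ctx.blocks();
      se::ThreadDim threads = ctx.threads();
      return ParamsFor(args, blocks, threads);  // hypothetical helper
    });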
@@ -99,6 +99,13 @@ class KernelInterface {
// Gets the preferred cache configuration.
virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;

// Returns the maximum number of blocks (per multiprocessor) occupied by the
// kernel given the number of threads per block and shared memory size.
virtual tsl::StatusOr<int32_t> GetMaxOccupiedBlocksPerCore(
ThreadDim threads, size_t dynamic_shared_memory_bytes) const {
return absl::UnimplementedError("Not Implemented");
}

private:
KernelInterface(const KernelInterface&) = delete;
void operator=(const KernelInterface&) = delete;
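
Editor's note: because the base class defaults to UnimplementedError, a packer that should also work on backends that never override the query has to treat the occupancy as optional. A small hedged sketch of such a fallback; the helper name and default value are illustrative, not part of this commit.

#include <cstdint>

#include "xla/stream_executor/kernel.h"

namespace se = stream_executor;

// Sketch: use the queried occupancy when available, otherwise fall back to a
// conservative default instead of failing the launch.
int32_t OccupancyOrDefault(const se::KernelLaunchContext &ctx,
                           const se::KernelArgs &args,
                           int32_t fallback = 1) {
  auto occupancy = ctx.kernel()->GetMaxOccupiedBlocksPerCore(
      ctx.threads(), args.number_of_shared_bytes());
  return occupancy.ok() ? *occupancy : fallback;
}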