From 1197e02141df1a7442f21ff6922c98ec0bba153e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 31 May 2024 20:21:38 -0400 Subject: [PATCH] [Build] Guard against older CUDA versions when building CUTLASS 3.x kernels (#5168) --- csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu | 10 ++++++++-- csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu | 11 ++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 5fd6d8ff20867..531414bc45165 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -1,3 +1,9 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + #include #include @@ -6,8 +12,6 @@ #include #include -// clang-format will break include orders -// clang-format off #include "cutlass/cutlass.h" #include "cute/tensor.hpp" @@ -241,3 +245,5 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, } } } + +#endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index dab73ac6c831e..eb532f2ac7a9b 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,5 +1,6 @@ +#include + #include -#include #include void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales); +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales); +#endif void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, if (version_num >= 90) { // Hopper + + // Guard against compilation issues for sm90 kernels +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); +#else + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); +#endif } else if (version_num == 89) { // Ada Lovelace cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);