diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index ee801e16573d4..dbb72e8bbd3f5 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -8,6 +8,10 @@ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" + +using namespace vllm; + /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). @@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales); } } @@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales); } } @@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { assert(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } } else { @@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_fp8_dispatch< - cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_fp8_dispatch( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales); } } @@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index 037235e52d7ac..d03242f44ab1d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,6 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "epilogue/broadcast_load_epilogue_c2x.hpp" #include "common.hpp" // clang-format on @@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - template - using ColOrScalarLoad = - cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = - cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using RowOrZeroLoad = - cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - // it would technically work but no use case as data_ptr is never nullptr - static_assert(!std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - static_assert(std::is_same_v>); - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : protected ScaledEpilogueBase { - protected: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 84e1f367c8722..33581a63d4c3d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,7 +1,291 @@ -#include +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + #include + +#include + +#include +#include +#include + #include "cutlass/cutlass.h" -#include "scaled_mm_c3x.cuh" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; +using namespace vllm; + +/* + This file defines quantized GEMM operations using the CUTLASS 3.x API, for + NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogue functions can be defined to post-process the output before it is + written to GPU memory. + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace { + +// A wrapper for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { + #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); + #endif + } +}; +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_default { + // For M > 128 and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M128 { + // For M in (64, 128] and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M64 { + // For M in (32, 64] and any N + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NBig { + // For M in [1, 32] and N >= 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NSmall { + // For M in [1, 32] and N < 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +} // namespace template typename Epilogue, @@ -139,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == c.dtype(), "currently bias dtype must match output dtype ", c.dtype()); - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( c, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_epilogue(c, a, b, a_scales, - b_scales); + return cutlass_scaled_mm_sm90_epilogue( + c, a, b, a_scales, b_scales); } } @@ -158,83 +442,12 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } -// hyper-parameter sweep kernels - -void cutlass_scaled_mm_sm90_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - c10::optional const& bias) { - assert(!bias); - - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using AccType = float; - - if (out.dtype() == torch::kBFloat16) { - using Cutlass3xGemm = - cutlass_3x_gemm; - - return cutlass_gemm_caller(out, a, b, a_scales, b_scales); - - } else { - TORCH_CHECK(out.dtype() == torch::kFloat16); - - using Cutlass3xGemm = - cutlass_3x_gemm; - - return cutlass_gemm_caller(out, a, b, a_scales, b_scales); - } -} - -void cutlass_simple_gemm_sm90_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b) { - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using AccType = float; - - if (out.dtype() == torch::kBFloat16) { - using Cutlass3xGemm = - cutlass_3x_simple_gemm; - - return cutlass_simple_gemm_caller(out, a, b); - - } else { - TORCH_CHECK(out.dtype() == torch::kFloat16); - - using Cutlass3xGemm = - cutlass_3x_simple_gemm; - - return cutlass_simple_gemm_caller(out, a, b); - } -} +#endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh deleted file mode 100644 index 92b35c394eeb8..0000000000000 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh +++ /dev/null @@ -1,786 +0,0 @@ -#pragma once - -// clang-format will break include orders -// clang-format off -#include - -#if defined CUDA_VERSION && CUDA_VERSION >= 12000 - -#include - -#include - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler_params.h" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "epilogue/broadcast_load_epilogue_c3x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; - -/* - This file defines quantized GEMM operations using the CUTLASS 3.x API, for - NVIDIA GPUs with sm90a (Hopper) or later. - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm90EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. -*/ - -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); - #endif - } -}; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - template - using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; - - // Don't want to support nullptr by default - template - using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // Don't want to support nullptr by default - template - using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - static_assert(!std::is_same_v> && - !std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - static_assert(std::is_same_v> || - std::is_same_v>); - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch.scaled_mm_. - - A and B may be both either int8 or fp8_e4m3. A can be - quantized per-tensor or per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; - -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule, typename AccType, - typename TileSchedule = cutlass::gemm::PersistentScheduler, - GemmUniversalMode Mode_ = GemmUniversalMode::kGemm> -struct cutlass_3x_gemm { - static const GemmUniversalMode Mode = Mode_; - using ElementAB = ElementAB_; - using ElementD = ElementD_; - - using ElementAcc = AccType; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using EVTCompute = typename Epilogue::EVTCompute; - - static constexpr int AlignmentA = - 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentB = - 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentCD = - 128 / cutlass::sizeof_bits::value; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, - AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, AlignmentA, - ElementAB, cutlass::layout::ColumnMajor, AlignmentB, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - TileSchedule>>; - - struct GemmKernel : public KernelType {}; -}; - -template -inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = b.size(0); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, - epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -using ReductionMode = cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::ReductionMode; -using DecompositionMode = cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::DecompositionMode; -using RasterOrderOptions = cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90Params::RasterOrderOptions; - -template -inline void cutlass_gemm_caller_streamk(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - ReductionMode reduction_mode, - DecompositionMode decomposition_mode, - EpilogueArgs&&... epilogue_params) { - static_assert(std::is_same::value, - "Must be streamk scheduler"); - - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::TileSchedulerArguments tile_scheduler_args( - 1, 1, RasterOrderOptions::Heuristic, decomposition_mode); - tile_scheduler_args.reduction_mode = reduction_mode; - - // Copied from examples... - // The KernelHardwareInfo struct holds the number of SMs on the GPU with a - // given device ID. This information is used by the underlying kernel. - cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. - hw_info.device_id = 0; - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - - typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, - mainloop_args, epilogue_args, - hw_info, tile_scheduler_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; - - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_default { - // For M > 128 and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M128 { - // For M in (64, 128] and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M64 { - // For M in (32, 64] and any N - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NBig { - // For M in [1, 32] and N >= 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NSmall { - // For M in [1, 32] and N < 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template -struct cutlass_3x_simple_gemm { - static const GemmUniversalMode Mode = Mode_; - using ElementAB = ElementAB_; - using ElementD = ElementD_; - - using ElementAcc = - typename std::conditional, AccType, - AccType>::type; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - TileSchedule>>; - - struct GemmKernel : public KernelType {}; -}; - -template -inline void cutlass_simple_gemm_caller(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - - typename GemmKernel::EpilogueArguments epilogue_args{ - {}, c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, - epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -#endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 1657f7d0b16e8..97a969cf5e3e0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -137,9 +137,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, return; } - // Turing - TORCH_CHECK(version_num >= 75); - cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + if (version_num >= 75) { + // Turing + cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + return; + } #endif TORCH_CHECK_NOT_IMPLEMENTED(