Commit

Address reviews; one compression test left to pass
Faraz9877 committed Dec 13, 2024
1 parent b559b6a commit 4c927a0
Showing 11 changed files with 73 additions and 108 deletions.
benchmarks/cutlass_benchmarks/sparse_benchmarks.py (4 additions, 6 deletions)
@@ -46,15 +46,14 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
-    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+                                       torch.bfloat16)
     out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
 
     if not torch.allclose(out, out_ref):
         print("Incorrect results")
         print(out)
         print(out_ref)
-    else:
-        print("Correct results")
 
     timers = []
     # pytorch impl - bfloat16
@@ -105,15 +104,14 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
-    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+                                       torch.bfloat16)
     out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
 
     if not torch.allclose(out, out_ref):
         print("Incorrect results")
         print(out)
         print(out_ref)
-    else:
-        print("Correct results")
 
     timers = []
 
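Note on the correctness check above: torch.allclose runs with its default
tolerances (rtol=1e-05, atol=1e-08), which can be strict when comparing
INT8/FP8 kernels against a dense reference. A minimal libtorch sketch of the
same check with explicit tolerances (tensors and tolerance values here are
illustrative assumptions, not taken from the benchmark):

// Compare a kernel output against a reference with explicit tolerances.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Stand-ins for `out` (sparse kernel) and `out_ref` (dense reference).
  torch::Tensor out_ref = torch::randn({16, 16});
  torch::Tensor out = out_ref + 1e-4 * torch::randn({16, 16});

  // Spelling out rtol/atol makes the pass criterion visible at the call site.
  bool ok = torch::allclose(out, out_ref, /*rtol=*/1e-3, /*atol=*/1e-3);
  if (!ok) {
    std::cout << "Incorrect results\n" << out << "\n" << out_ref << "\n";
  }
  return ok ? 0 : 1;
}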
csrc/core/math.hpp (new file, 7 additions)
@@ -0,0 +1,7 @@
+#include <climits>
+#include <iostream>
+
+inline uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
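The helper rounds up to the nearest power of two and is an identity on values
that already are one. A standalone check of its behavior (not part of the
commit; note that __builtin_clz is a GCC/Clang builtin):

#include <climits>
#include <cstdint>
#include <iostream>

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  std::cout << next_pow_2(0) << "\n";   // 0 (passed through)
  std::cout << next_pow_2(1) << "\n";   // 1
  std::cout << next_pow_2(5) << "\n";   // 8
  std::cout << next_pow_2(64) << "\n";  // 64 (already a power of two)
  std::cout << next_pow_2(65) << "\n";  // 128
  return 0;
}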
csrc/cutlass_extensions/common.hpp (5 additions, 14 deletions)

@@ -11,25 +11,16 @@
 #define CUTLASS_CHECK(status)                        \
   {                                                  \
     TORCH_CHECK(status == cutlass::Status::kSuccess, \
-                cutlassGetStatusString(status))      \
+                cutlassGetStatusString(status));     \
   }
 
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
 /**
  * Panic wrapper for unwinding CUDA runtime errors
  */
-#define CUDA_CHECK(status)                                              \
-  {                                                                     \
-    cudaError_t error = status;                                         \
-    if (error != cudaSuccess) {                                         \
-      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
-                << " at line: " << __LINE__ << std::endl;               \
-      exit(EXIT_FAILURE);                                               \
-    }                                                                   \
+#define CUDA_CHECK(status)                                        \
+  {                                                               \
+    cudaError_t error = status;                                   \
+    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
   }
 
 inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
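The review change above swaps the exit(EXIT_FAILURE) path for TORCH_CHECK,
which throws c10::Error instead of killing the process; PyTorch's bindings
surface that as a Python RuntimeError the caller can handle. A small sketch of
the difference (assumes libtorch and the CUDA runtime; the bad device id is an
arbitrary example):

#include <c10/util/Exception.h>
#include <cuda_runtime.h>
#include <iostream>

#define CUDA_CHECK(status)                                        \
  {                                                               \
    cudaError_t error = status;                                   \
    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
  }

int main() {
  try {
    CUDA_CHECK(cudaSetDevice(9999));  // intentionally invalid device id
  } catch (const c10::Error& e) {
    // With the old macro this would have been exit(EXIT_FAILURE); now the
    // error is recoverable and carries the CUDA error string.
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}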
csrc/ops.h (1 addition, 1 deletion)

@@ -162,7 +162,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               c10::optional<torch::Tensor> const& bias);
 
 bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
-                            torch::Tensor const& a);
+                             torch::Tensor const& a);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh (1 addition)

@@ -21,6 +21,7 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu (1 addition)

@@ -24,6 +24,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

csrc/sparse/cutlass/sparse_compressor.cu (5 additions, 29 deletions)

@@ -1,46 +1,22 @@
 // clang-format will break include orders
 // clang-format off
 #include <cudaTypedefs.h>
 
 #include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 
-#include <iostream>
-#include <sstream>
-#include <vector>
-
 #include "cutlass/cutlass.h"
+#include "sparse_scaled_mm_c3x.cuh"
 
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-#include "cutlass/numeric_conversion.h"
-#include "cutlass/detail/dependent_false.hpp"
-
-#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
 #include "cutlass_extensions/common.hpp"
 
 #include "cutlass/transform/device/transform_universal_adapter.hpp"
 #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
 
-#include "cutlass/epilogue/collective/default_epilogue.hpp"
-#include "cutlass/epilogue/thread/linear_combination.h"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-
-#include <iostream>
-
-#include "cutlass/cutlass.h"
-
-#include "cutlass/tensor_ref.h"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/dispatch_policy.hpp"
-
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/packed_stride.hpp"
 
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "sparse_scaled_mm_c3x.cuh"
 // clang-format on
 
 using namespace cute;
 using namespace vllm;
@@ -153,7 +129,7 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
 }
 
 bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
-                            torch::Tensor const& a) {
+                             torch::Tensor const& a) {
   if (a.dtype() == torch::kBFloat16) {
     return sparsify_and_compress<cutlass::bfloat16_t>(a_compressed, e, a);
   } else if (a.dtype() == torch::kFloat16) {
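For context on the two output buffers: with 2:4 structured sparsity, each
group of four elements keeps two nonzeros, so the compressed tensor has K/2
columns and the metadata tensor records which positions survived. The shapes
below are illustrative assumptions (the real buffers are sized by the CUTLASS
SparseConfig inside the kernel), and compress_2_4 is a hypothetical helper,
not part of this commit:

#include <torch/torch.h>
#include <utility>

// Declared in csrc/ops.h; implemented by the dtype dispatch shown above.
bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
                             torch::Tensor const& a);

std::pair<torch::Tensor, torch::Tensor> compress_2_4(torch::Tensor const& a) {
  int64_t m = a.size(0), k = a.size(1);
  // Assumed 2:4 shapes: K/2 kept values per row; 2 bits of metadata per kept
  // value packs to roughly k/8 bytes per row.
  auto a_compressed = torch::empty({m, k / 2}, a.options());
  auto e = torch::empty({m, k / 8}, a.options().dtype(torch::kUInt8));
  TORCH_CHECK(cutlass_sparse_compress(a_compressed, e, a),
              "compression failed");
  return {a_compressed, e};
}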
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu (17 additions, 31 deletions)

@@ -3,32 +3,10 @@
 #include <cudaTypedefs.h>
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
-
-#include <torch/all.h>
-
-#include <ATen/cuda/CUDAContext.h>
-
-#include <iostream>
-#include <sstream>
-#include <vector>
-
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-
+#include "sparse_scaled_mm_c3x.cuh"
 #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 #include "cutlass_extensions/common.hpp"
 // clang-format on
 
-#include "sparse_scaled_mm_c3x.cuh"
-
 using namespace cute;
 using namespace vllm;
 
@@ -247,37 +225,43 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                              Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else if (a.dtype() == torch::kFloat8_e4m3fn) {
     TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                             cutlass::bfloat16_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                             cutlass::half_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else if (a.dtype() == torch::kFloat16) {
     TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
 
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
                                              cutlass::bfloat16_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
                                              Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {  // a.dtype() == torch::kBFloat16
     TORCH_CHECK(a.dtype() == torch::kBFloat16);
@@ -286,12 +270,14 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                              cutlass::bfloat16_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
       return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                              cutlass::half_t, Epilogue>(
-          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
+          out, a, bt_nzs, bt_meta,
+          std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
 }
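All of the reflowed call sites follow the same shape: pick template parameters
from runtime dtypes, then perfect-forward the epilogue argument pack. A
self-contained sketch of that pattern (generic C++ stand-ins, not vLLM's
actual types):

#include <iostream>
#include <utility>

// Stand-in for cutlass_gemm_sm90_*_dispatch: the pack arrives unchanged.
template <typename InT, typename OutT, typename... EpilogueArgs>
void gemm_dispatch(EpilogueArgs&&... epilogue_args) {
  std::cout << "dispatch with " << sizeof...(EpilogueArgs) << " args\n";
}

template <typename... EpilogueArgs>
void entry(bool out_is_bf16, EpilogueArgs&&... epilogue_args) {
  if (out_is_bf16) {
    return gemm_dispatch<int, float>(
        std::forward<EpilogueArgs>(epilogue_args)...);
  }
  return gemm_dispatch<int, double>(
      std::forward<EpilogueArgs>(epilogue_args)...);
}

int main() {
  double scale_a = 1.0, scale_b = 1.0;
  entry(true, scale_a, scale_b);  // prints: dispatch with 2 args
  return 0;
}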
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh (17 additions, 19 deletions)

@@ -1,29 +1,24 @@
 // clang-format will break include orders
 // clang-format off
 #include <cudaTypedefs.h>
 
 #include <torch/all.h>
 
 #include <ATen/cuda/CUDAContext.h>
 
-#include <iostream>
-#include <sstream>
-#include <vector>
-
 #include "cutlass/cutlass.h"
 
 #include "cute/tensor.hpp"
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"
 
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 #include "cutlass/gemm/kernel/tile_scheduler_params.h"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
 #include "cutlass_extensions/cute_utils.cuh"
-#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
+#include "core/math.hpp"
 #include "cutlass_extensions/common.hpp"
 #include "cutlass_extensions/torch_utils.hpp"
 // clang-format on
 
 using namespace cute;
 
@@ -94,8 +89,9 @@ struct cutlass_sparse_3x_gemm {
   typename cutlass::epilogue::collective::CollectiveBuilder<
       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
       ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-      ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD, ElementD,
-      LayoutD_Transpose, AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
+      ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
+      ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
+      EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
@@ -121,8 +117,7 @@
 };
 
 template <typename Gemm, typename... EpilogueArgs>
-void cutlass_sparse_gemm_caller(torch::Tensor& out,
-                                torch::Tensor const& a,
+void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& bt_nzs,
                                 torch::Tensor const& bt_meta,
                                 EpilogueArgs&&... epilogue_params) {
@@ -145,7 +140,8 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out,
   auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
 
   using GemmKernel = typename Gemm::GemmKernel;
-  typename GemmKernel::ProblemShape prob_shape{(int) bt_nzs.size(0), (int) size<0>(layout_A), (int) size<1>(layout_A), 1};
+  typename GemmKernel::ProblemShape prob_shape{
+      (int)bt_nzs.size(0), (int)size<0>(layout_A), (int)size<1>(layout_A), 1};
 
   using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
   using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
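The reformatted prob_shape packs {bt_nzs.size(0), rows(A), cols(A), 1}, which
is consistent with the sparse operand being CUTLASS operand A: the kernel
appears to compute C^T = B^T x A^T and write the transposed output. A worked
shape example (inferred from this call site, not from CUTLASS documentation):

#include <cstdio>

int main() {
  // Logical problem: C[M, N] = A[M, K] * B[K, N]
  int M = 16, K = 64, N = 32;
  int bt_nzs_rows = N;       // compressed B^T has one row per column of B
  // GEMM dimensions as packed into ProblemShape above:
  int gemm_m = bt_nzs_rows;  // sparse operand drives the M dimension
  int gemm_n = M;            // rows of A become the GEMM N dimension
  int gemm_k = K;
  std::printf("ProblemShape{%d, %d, %d, 1}\n", gemm_m, gemm_n, gemm_k);
  return 0;
}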
@@ -198,7 +194,7 @@ struct sm90_config_default<half_t, OutType, Epilogue> {
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
       cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, float>;
+                             KernelSchedule, EpilogueSchedule, float>;
 };
 
 template <typename OutType,
@@ -210,8 +206,9 @@ struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_2, _1, _1>;
   using Cutlass3xGemm =
-      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule, float>;
+      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
+                             ClusterShape, KernelSchedule, EpilogueSchedule,
+                             float>;
 };

 //////////////////////// Cherry-Picking Kernels ////////////////////////
@@ -337,8 +334,9 @@ struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
   using TileShape = Shape<_128, _128, _128>;
   using ClusterShape = Shape<_1, _2, _1>;
   using Cutlass3xGemm =
-      cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue, TileShape, ClusterShape,
-                             KernelSchedule, EpilogueSchedule, float>;
+      cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
+                             TileShape, ClusterShape, KernelSchedule,
+                             EpilogueSchedule, float>;
 };

 template <typename InType, typename OutType,
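The sm90_config_default specializations above all follow one pattern: a
primary template supplies fallback tile/cluster shapes, and per-dtype
specializations override them. A standalone analogue of that pattern (plain
C++ stand-ins, not the actual CUTLASS types):

#include <cstdio>

struct DefaultTile { static constexpr int kM = 64, kN = 64, kK = 64; };
struct WideTile    { static constexpr int kM = 128, kN = 128, kK = 128; };

template <typename InT>
struct sm90_config { using Tile = DefaultTile; };  // fallback

template <>
struct sm90_config<float> { using Tile = WideTile; };  // per-dtype override

int main() {
  using Tile = sm90_config<float>::Tile;
  std::printf("tile = %dx%dx%d\n", Tile::kM, Tile::kN, Tile::kK);
  return 0;
}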
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu (6 additions, 4 deletions)

@@ -36,9 +36,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));
 
   // Check for strides and alignment
-  TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 && c.stride(1) == 1); // Row-major
-  TORCH_CHECK(c.stride(0) % 16 == 0); // 16 Byte Alignment
-  TORCH_CHECK(bt_nzs.stride(0) % 16 == 0); // 16 Byte Alignment
+  TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
+              c.stride(1) == 1);                 // Row-major
+  TORCH_CHECK(c.stride(0) % 16 == 0);            // 16 Byte Alignment
+  TORCH_CHECK(bt_nzs.stride(0) % 16 == 0);       // 16 Byte Alignment
   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
 
   if (bias) {
@@ -52,7 +53,8 @@
   // Guard against compilation issues for sm90 kernels
 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
   if (version_num >= 90) {
-    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales, bias);
+    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
+                                  bias);
     return;
   }
 #endif
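On the stride checks above: stride(1) == 1 pins each tensor to row-major
layout, and stride(0) % 16 == 0 requires the leading stride to be a multiple
of 16 elements (so the resulting byte alignment also depends on the element
width). A standalone illustration with libtorch (the shapes are arbitrary):

#include <torch/torch.h>
#include <iostream>

int main() {
  // A freshly allocated contiguous tensor satisfies both conditions.
  auto c = torch::empty({32, 64}, torch::kBFloat16);
  std::cout << "strides: " << c.stride(0) << ", " << c.stride(1) << "\n";
  bool row_major = (c.stride(1) == 1);     // innermost stride is 1
  bool aligned = (c.stride(0) % 16 == 0);  // 64 % 16 == 0
  std::cout << std::boolalpha << row_major << " " << aligned << "\n";
  return 0;
}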