Merge pull request #35 from neuralmagic/rob/semi-structured
[ 2:4 Sparse ] Cherry Pick Latest Kernel From Faraz
dsikka authored Dec 6, 2024
2 parents 34a84a4 + 7c61ab0 commit a8a1b57
Showing 9 changed files with 215 additions and 42 deletions.
6 changes: 3 additions & 3 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py
@@ -96,16 +96,16 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# Create tensors
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
aT = a.t()
bT = b.t()
bT = b
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16)

if not torch.allclose(out.t(), out_ref):
print("Incorrect result")
if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2):
print(f"Incorrect result for {m}, {k}, {n}")
exit()

timers = []
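
Note on the change above: the FP8 sparse kernel and the dense reference do not accumulate identically, so exact equality is not expected and the check now compares the untransposed output with relative and absolute tolerances. A minimal, self-contained sketch of that pattern (the shapes and the simulated error below are illustrative, not taken from the benchmark):

import torch

# Compare a result against a reference with tolerances, as in the check above.
m, k, n = 128, 256, 64
a = torch.randn(m, k)
b = torch.randn(k, n)

out_ref = a @ b
out = out_ref + 1e-3 * torch.randn_like(out_ref)  # stand-in for small numerical error

if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2):
    print(f"Incorrect result for {m}, {k}, {n}")
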
42 changes: 20 additions & 22 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py
@@ -242,9 +242,8 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue)
dtype, m, n, k
)
AsT = [x.t() for x in As]
BsT = [x.t() for x in Bs]
bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT]
bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
# Because the transposed output will be computed
@@ -266,35 +265,35 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue)
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
ArgPool(bf16_As), ArgPool(bf16_BsT))
ArgPool(bf16_As), ArgPool(bf16_Bs))

elif kernel_type == 'pytorch_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16)

elif kernel_type == 'pytorch_scaled_mm_fast':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True)

elif kernel_type == 'cutlass_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm_default",
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
ArgPool(As), ArgPool(BsT), scale_a, scale_b,
ArgPool(As), ArgPool(Bs), scale_a, scale_b,
torch.bfloat16)

elif kernel_type == 'cutlass_sparse_mm':
elif kernel_type == 'cutlass_scaled_sparse_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_default",
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
ArgPool(BComps), ArgPool(Es), ArgPool(AsT),
scale_b, scale_a, torch.bfloat16)
@@ -417,9 +416,8 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu
dtype, m, n, k
)
AsT = [x.t() for x in As]
BsT = [x.t() for x in Bs]
bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT]
bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
out = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda")
@@ -440,35 +438,35 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
ArgPool(bf16_As), ArgPool(bf16_BsT))
ArgPool(bf16_As), ArgPool(bf16_Bs))

elif kernel_type == 'pytorch_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16)

elif kernel_type == 'pytorch_scaled_mm_fast':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True)

elif kernel_type == 'cutlass_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm_default",
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
ArgPool(As), ArgPool(BsT), scale_a, scale_b,
ArgPool(As), ArgPool(Bs), scale_a, scale_b,
torch.bfloat16)

elif kernel_type == 'cutlass_sparse_mm':
elif kernel_type == 'cutlass_scaled_sparse_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_default",
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
ArgPool(BComps), ArgPool(Es), ArgPool(AsT),
scale_b, scale_a, torch.bfloat16)
@@ -525,11 +523,11 @@ def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int],

# Prepare configs for all kernels
standard_kernels = [
{'kernel_type': 'pytorch_mm'},
{'kernel_type': 'pytorch_scaled_mm'},
{'kernel_type': 'pytorch_scaled_mm_fast'},
# {'kernel_type': 'pytorch_mm'},
# {'kernel_type': 'pytorch_scaled_mm'},
# {'kernel_type': 'pytorch_scaled_mm_fast'},
{'kernel_type': 'cutlass_scaled_mm'},
{'kernel_type': 'cutlass_sparse_mm'}
{'kernel_type': 'cutlass_scaled_sparse_mm'}
]

# Create configs for standard kernels
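
The standard_kernels list above appears to be expanded into per-kernel benchmark configs just below this hunk ("Create configs for standard kernels"). A hypothetical sketch of such an expansion, with field names and shape values assumed for illustration rather than taken from the file:

from itertools import product

standard_kernels = [
    {'kernel_type': 'cutlass_scaled_mm'},
    {'kernel_type': 'cutlass_scaled_sparse_mm'},
]
shapes = [(16, 4096, 4096), (16, 4096, 8192)]  # (m, k, n): example values only

# Cross every kernel entry with every problem shape.
configs = [{**kernel, 'm': m, 'k': k, 'n': n}
           for kernel, (m, k, n) in product(standard_kernels, shapes)]
print(len(configs))  # 4
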
6 changes: 3 additions & 3 deletions benchmarks/cutlass_benchmarks/sparse_mm/utils.py
@@ -49,14 +49,14 @@ def prune_to_2_4(tensor):
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

# # Initialize a to all ones
# a = torch.ones((m, k), device='cuda')
# # Initialize b to all ones
# b = torch.ones((n, k), device='cuda')

b = prune_to_2_4(b)
b = prune_to_2_4(b.t()).t()

if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
@@ -69,7 +69,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
else:
raise ValueError("unsupported dtype")

b_compressed, e = ops.cutlass_compress_entry(b)
b_compressed, e = ops.cutlass_compress_entry(b.t())

# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
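
prune_to_2_4 is defined earlier in this file and is not shown in the hunk; for readers unfamiliar with 2:4 semi-structured sparsity, a minimal sketch of magnitude-based 2:4 pruning follows. This illustrates the general technique and is not the repository's exact implementation:

import torch

def prune_to_2_4_sketch(t: torch.Tensor) -> torch.Tensor:
    # Keep the two largest-magnitude entries in every contiguous group of four
    # along the last dimension and zero the rest (2:4 semi-structured sparsity).
    rows, cols = t.shape
    assert cols % 4 == 0, "last dimension must be divisible by 4"
    groups = t.reshape(rows, cols // 4, 4)
    _, keep = groups.abs().topk(2, dim=-1)
    mask = torch.zeros_like(groups).scatter_(-1, keep, 1.0)
    return (groups * mask).reshape(rows, cols)

# In the hunk above, b is now created in (k, n) layout, so it is transposed to
# (n, k) before pruning and compression and transposed back afterwards, which
# keeps the 2:4 pattern along the reduction dimension k.
print(prune_to_2_4_sketch(torch.randn(8, 16)))
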
2 changes: 1 addition & 1 deletion benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py
@@ -66,7 +66,7 @@
# Ns : [2560, 4096, 5120, 6144, 8192, 12288, 14336, 15360,
# 22016, 27648, 28672]
"llama-representative-set": [
# ([4096, 4096], None), # small K, small N
([4096, 4096], None), # small K, small N
([4096, 8192], None), # small K, medium N
([4096, 22016], None), # small K, large N
([14336, 4096], None), # large K, small N
82 changes: 72 additions & 10 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -52,20 +52,85 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM512 =
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

uint32_t const n = b.size(1);
uint32_t const mp2 =

using Cutlass3xGemm1 =
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm2 =
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm3 =
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm4 =
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm5 =
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm6 =
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm7 =
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm8 =
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

uint32_t const n = b.size(1); // Batch size
uint32_t const m = a.size(0);
uint32_t const np2 =
std::max(static_cast<uint32_t>(64), next_pow_2(n)); // next power of 2

if (mp2 <= 64) {
if (np2 <= 64) {
if (m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 4096 || m == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
} else if (np2 <= 128) {
if (m == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
} else if (np2 <= 256) {
if (m == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
} else {
if (m == 6144 || m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
}

// Otherwise the default heuristic
if (np2 <= 64) {
// n in [1, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
} else if (np2 <= 128) {
// n in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
} else if (np2 <= 256) {
// n in (128, 256]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
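
The branches added above special-case a few specific m values (4096, 6144, 28672) before falling through to the generic power-of-two bucketing on n; the final bucket (Cutlass3xGemmM512) lies below the portion shown. A rough Python rendering of just that fallback bucketing, assuming next_pow_2 rounds up to the next power of two (the Cutlass3xGemm* names refer to the config aliases defined in the surrounding C++):

def next_pow_2(x: int) -> int:
    # Round x up to the next power of two (1 for x <= 1).
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def pick_fallback_config(n: int) -> str:
    # Mirror of the default heuristic: bucket n by its next power of two, clamped to >= 64.
    np2 = max(64, next_pow_2(n))
    if np2 <= 64:       # n in [1, 64]
        return "Cutlass3xGemmM64"
    if np2 <= 128:      # n in (64, 128]
        return "Cutlass3xGemmM128"
    if np2 <= 256:      # n in (128, 256]
        return "Cutlass3xGemmM256"
    return "Cutlass3xGemmM512"  # n > 256

print(pick_fallback_config(96))  # -> Cutlass3xGemmM128
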
@@ -181,8 +246,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
torch::Tensor const& e,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
Expand All @@ -195,7 +260,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
out, a, e, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

if (out.dtype() == torch::kBFloat16) {
Expand All @@ -210,7 +274,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
}
}
else if (a.dtype() == torch::kFloat16) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kFloat16);

if (out.dtype() == torch::kBFloat16) {
Expand All @@ -226,7 +289,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
}
else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kBFloat16);

if (out.dtype() == torch::kBFloat16) {