[ 2:4 Sparse ] Cherry Pick Latest Kernel From Faraz #35

Merged · 2 commits · Dec 6, 2024
6 changes: 3 additions & 3 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py
@@ -96,16 +96,16 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# Create tensors
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
aT = a.t()
bT = b.t()
bT = b
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16)

if not torch.allclose(out.t(), out_ref):
print("Incorrect result")
if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2):
print(f"Incorrect result for {m}, {k}, {n}")
exit()

timers = []
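A note on the relaxed correctness check above: the sparse kernel and the dense reference quantize their inputs and accumulate through different rounding paths, so they can only agree to within roughly the quantization error, and a default-tolerance allclose is too strict. A minimal CPU-only sketch of the effect, using a bf16 round-trip to stand in for FP8 (illustration only, not part of the diff):

```python
import torch

torch.manual_seed(0)
a = torch.rand(256, 512) * 5
b = torch.rand(512, 128) * 5

# Emulate a low-precision input path (FP8 in the benchmark) with a bf16
# round-trip, then compare against the full-precision reference.
a_q = a.to(torch.bfloat16).to(torch.float32)
b_q = b.to(torch.bfloat16).to(torch.float32)

out = a_q @ b_q      # stands in for the quantized kernel's result
out_ref = a @ b      # full-precision reference

print(torch.allclose(out, out_ref))                        # False: default rtol=1e-5 is too strict
print(torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2))  # True: tolerances absorb quantization error
```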
42 changes: 20 additions & 22 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py
@@ -242,9 +242,8 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue)
dtype, m, n, k
)
AsT = [x.t() for x in As]
BsT = [x.t() for x in Bs]
bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT]
bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
# Because the transposed output will be computed
@@ -266,35 +265,35 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue)
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
ArgPool(bf16_As), ArgPool(bf16_BsT))
ArgPool(bf16_As), ArgPool(bf16_Bs))

elif kernel_type == 'pytorch_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16)

elif kernel_type == 'pytorch_scaled_mm_fast':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True)

elif kernel_type == 'cutlass_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm_default",
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
ArgPool(As), ArgPool(BsT), scale_a, scale_b,
ArgPool(As), ArgPool(Bs), scale_a, scale_b,
torch.bfloat16)

elif kernel_type == 'cutlass_sparse_mm':
elif kernel_type == 'cutlass_scaled_sparse_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_default",
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
ArgPool(BComps), ArgPool(Es), ArgPool(AsT),
scale_b, scale_a, torch.bfloat16)
@@ -417,9 +416,8 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu
dtype, m, n, k
)
AsT = [x.t() for x in As]
BsT = [x.t() for x in Bs]
bf16_As = [x.to(dtype=torch.bfloat16) for x in As]
bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT]
bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs]
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
out = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda")
@@ -440,35 +438,35 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
ArgPool(bf16_As), ArgPool(bf16_BsT))
ArgPool(bf16_As), ArgPool(bf16_Bs))

elif kernel_type == 'pytorch_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16)

elif kernel_type == 'pytorch_scaled_mm_fast':
bench = BenchMM(cuda_graph_params, label, sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
ArgPool(As), ArgPool(BsT),
ArgPool(As), ArgPool(Bs),
scale_a=scale_a, scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True)

elif kernel_type == 'cutlass_scaled_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm_default",
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
ArgPool(As), ArgPool(BsT), scale_a, scale_b,
ArgPool(As), ArgPool(Bs), scale_a, scale_b,
torch.bfloat16)

elif kernel_type == 'cutlass_sparse_mm':
elif kernel_type == 'cutlass_scaled_sparse_mm':
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_default",
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
ArgPool(BComps), ArgPool(Es), ArgPool(AsT),
scale_b, scale_a, torch.bfloat16)
@@ -525,11 +523,11 @@ def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int],

# Prepare configs for all kernels
standard_kernels = [
{'kernel_type': 'pytorch_mm'},
{'kernel_type': 'pytorch_scaled_mm'},
{'kernel_type': 'pytorch_scaled_mm_fast'},
# {'kernel_type': 'pytorch_mm'},
# {'kernel_type': 'pytorch_scaled_mm'},
# {'kernel_type': 'pytorch_scaled_mm_fast'},
{'kernel_type': 'cutlass_scaled_mm'},
{'kernel_type': 'cutlass_sparse_mm'}
{'kernel_type': 'cutlass_scaled_sparse_mm'}
]

# Create configs for standard kernels
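For readers outside this repo: BenchMM and ArgPool come from the surrounding benchmark harness, not from this diff. The pooling idea is roughly that each timed call pulls the next pre-built operand from a pool, so CUDA-graph capture and repeated iterations do not keep hitting one cached tensor. A hypothetical sketch of that pattern (names and details assumed, not taken from the actual harness):

```python
import itertools
from typing import Any, Callable, List


class ArgPool:
    """Cycle through a pool of pre-built arguments, one per benchmark call."""

    def __init__(self, values: List[Any]):
        self._cycle = itertools.cycle(values)

    def next(self) -> Any:
        return next(self._cycle)


def call_with_pools(fn: Callable, *args: Any, **kwargs: Any) -> Any:
    # Resolve pooled arguments immediately before each invocation.
    real_args = [a.next() if isinstance(a, ArgPool) else a for a in args]
    real_kwargs = {k: (v.next() if isinstance(v, ArgPool) else v)
                   for k, v in kwargs.items()}
    return fn(*real_args, **real_kwargs)
```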
6 changes: 3 additions & 3 deletions benchmarks/cutlass_benchmarks/sparse_mm/utils.py
@@ -49,14 +49,14 @@ def prune_to_2_4(tensor):
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

# # Initialize a to all ones
# a = torch.ones((m, k), device='cuda')
# # Initialize b to all ones
# b = torch.ones((n, k), device='cuda')

b = prune_to_2_4(b)
b = prune_to_2_4(b.t()).t()

if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
@@ -69,7 +69,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
else:
raise ValueError("unsupported dtype")

b_compressed, e = ops.cutlass_compress_entry(b)
b_compressed, e = ops.cutlass_compress_entry(b.t())

# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
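Context for the prune_to_2_4 call (its body sits outside this diff): 2:4 structured sparsity keeps at most two non-zero values in every group of four consecutive elements along the reduction (K) dimension, which is why the change above transposes b back to its (n, k) view before pruning and compression even though the dense reference now consumes the (k, n) layout. A rough magnitude-based sketch of such a helper, as an assumption about its behavior rather than this repo's implementation:

```python
import torch


def prune_to_2_4(tensor: torch.Tensor) -> torch.Tensor:
    """Zero all but the 2 largest-magnitude values in each group of 4 along the last dim."""
    assert tensor.shape[-1] % 4 == 0, "last dimension must be a multiple of 4"
    groups = tensor.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices           # 2 survivors per group of 4
    mask = torch.zeros_like(groups).scatter_(-1, keep, 1.0)
    return (groups * mask).reshape(tensor.shape)
```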
2 changes: 1 addition & 1 deletion benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py
@@ -66,7 +66,7 @@
# Ns : [2560, 4096, 5120, 6144, 8192, 12288, 14336, 15360,
# 22016, 27648, 28672]
"llama-representative-set": [
# ([4096, 4096], None), # small K, small N
([4096, 4096], None), # small K, small N
([4096, 8192], None), # small K, medium N
([4096, 22016], None), # small K, large N
([14336, 4096], None), # large K, small N
82 changes: 72 additions & 10 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -52,20 +52,85 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM512 =
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

uint32_t const n = b.size(1);
uint32_t const mp2 =

using Cutlass3xGemm1 =
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm2 =
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm3 =
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm4 =
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm5 =
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm6 =
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm7 =
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm8 =
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

uint32_t const n = b.size(1); // Batch size
uint32_t const m = a.size(0);
uint32_t const np2 =
std::max(static_cast<uint32_t>(64), next_pow_2(n)); // next power of 2

if (mp2 <= 64) {
if (np2 <= 64) {
if (m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 4096 || m == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
} else if (np2 <= 128) {
if (m == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
} else if (np2 <= 256) {
if (m == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
} else {
if (m == 6144 || m == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
else if (m == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
}
}

// Otherwise the default heuristic
if (np2 <= 64) {
// n in [1, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
} else if (np2 <= 128) {
// n in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
} else if (np2 <= 256) {
// n in (128, 256]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
out, a, e, b, std::forward<EpilogueArgs>(args)...);
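Restating the new shape special-casing compactly may help review: n is bucketed to the next power of two (clamped to at least 64), a handful of m values matching the benchmarked Llama weight shapes get hand-tuned configs (Cutlass3xGemm1..8), and everything else falls through to the pre-existing M64/M128/M256/M512 heuristic. A Python paraphrase of that selection logic, for illustration only; the C++ above is authoritative, and the final fallback config is an assumption since the tail of the hunk is folded:

```python
def next_pow_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def pick_config(m: int, n: int) -> str:
    """Paraphrase of the dispatch above: hand-tuned configs for known shapes,
    otherwise the generic n-bucketed heuristic."""
    np2 = max(64, next_pow_2(n))
    bucket = 64 if np2 <= 64 else 128 if np2 <= 128 else 256 if np2 <= 256 else 512
    hand_tuned = {
        64:  {4096: "Cutlass3xGemm1", 6144: "Cutlass3xGemm1", 28672: "Cutlass3xGemm2"},
        128: {4096: "Cutlass3xGemm3", 6144: "Cutlass3xGemm4", 28672: "Cutlass3xGemm5"},
        256: {4096: "Cutlass3xGemm6", 6144: "Cutlass3xGemm7", 28672: "Cutlass3xGemm8"},
        512: {4096: "Cutlass3xGemm7", 6144: "Cutlass3xGemm8", 28672: "Cutlass3xGemm8"},
    }
    if m in hand_tuned[bucket]:
        return hand_tuned[bucket][m]
    generic = {64: "Cutlass3xGemmM64", 128: "Cutlass3xGemmM128", 256: "Cutlass3xGemmM256"}
    return generic.get(bucket, "Cutlass3xGemmM512")  # M512 fallback assumed (hidden by the fold)
```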
@@ -181,8 +246,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
torch::Tensor const& e,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
Expand All @@ -195,7 +260,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
out, a, e, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

if (out.dtype() == torch::kBFloat16) {
Expand All @@ -210,7 +274,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
}
}
else if (a.dtype() == torch::kFloat16) {
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kFloat16);

if (out.dtype() == torch::kBFloat16) {
Expand All @@ -226,7 +289,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
}
else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(e.dtype() == torch::kUInt8);
TORCH_CHECK(b.dtype() == torch::kBFloat16);

if (out.dtype() == torch::kBFloat16) {