From 4045bdabe4e7ea357d5263665a2c3362e62d854e Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Wed, 20 Nov 2024 23:57:56 +0000 Subject: [PATCH 1/2] Add cherry-picked heuristic for Llama3 8B model --- .../cutlass_benchmarks/sparse_mm/bench_v1.py | 6 +- .../cutlass_benchmarks/sparse_mm/bench_v2.py | 42 +++---- .../cutlass_benchmarks/sparse_mm/utils.py | 6 +- .../sparse_mm/weight_shapes.py | 2 +- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 82 +++++++++++-- csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 114 ++++++++++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 2 +- vllm/_custom_ops.py | 2 +- 8 files changed, 215 insertions(+), 41 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py index d7d585bb6956d..938611dda7447 100644 --- a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py +++ b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py @@ -96,7 +96,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # Create tensors b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) aT = a.t() - bT = b.t() + bT = b scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) @@ -104,8 +104,8 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16) out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16) - if not torch.allclose(out.t(), out_ref): - print("Incorrect result") + if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2): + print(f"Incorrect result for {m}, {k}, {n}") exit() timers = [] diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py index f9b4871044526..19e85da657ec1 100644 --- a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py +++ b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py @@ -242,9 +242,8 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue) dtype, m, n, k ) AsT = [x.t() for x in As] - BsT = [x.t() for x in Bs] bf16_As = [x.to(dtype=torch.bfloat16) for x in As] - bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT] + bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs] scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) # Because the transposed output will be computed @@ -266,13 +265,13 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue) bench = BenchMM(cuda_graph_params, label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", torch.mm, - ArgPool(bf16_As), ArgPool(bf16_BsT)) + ArgPool(bf16_As), ArgPool(bf16_Bs)) elif kernel_type == 'pytorch_scaled_mm': bench = BenchMM(cuda_graph_params, label, sub_label, "pytorch_fp8_fp8_bf16_scaled_mm", torch._scaled_mm, - ArgPool(As), ArgPool(BsT), + ArgPool(As), ArgPool(Bs), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16) @@ -280,21 +279,21 @@ def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue) bench = BenchMM(cuda_graph_params, label, sub_label, "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", torch._scaled_mm, - ArgPool(As), ArgPool(BsT), + ArgPool(As), ArgPool(Bs), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16, use_fast_accum=True) elif kernel_type == 'cutlass_scaled_mm': bench = 
BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_mm_default", + "cutlass_fp8_fp8_bf16_scaled_mm", ops.cutlass_scaled_mm, - ArgPool(As), ArgPool(BsT), scale_a, scale_b, + ArgPool(As), ArgPool(Bs), scale_a, scale_b, torch.bfloat16) - elif kernel_type == 'cutlass_sparse_mm': + elif kernel_type == 'cutlass_scaled_sparse_mm': bench = BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm_default", + "cutlass_fp8_fp8_bf16_scaled_sparse_mm", ops.cutlass_scaled_sparse_mm, ArgPool(BComps), ArgPool(Es), ArgPool(AsT), scale_b, scale_a, torch.bfloat16) @@ -417,9 +416,8 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu dtype, m, n, k ) AsT = [x.t() for x in As] - BsT = [x.t() for x in Bs] bf16_As = [x.to(dtype=torch.bfloat16) for x in As] - bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT] + bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs] scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) out = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda") @@ -440,13 +438,13 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu bench = BenchMM(cuda_graph_params, label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", torch.mm, - ArgPool(bf16_As), ArgPool(bf16_BsT)) + ArgPool(bf16_As), ArgPool(bf16_Bs)) elif kernel_type == 'pytorch_scaled_mm': bench = BenchMM(cuda_graph_params, label, sub_label, "pytorch_fp8_fp8_bf16_scaled_mm", torch._scaled_mm, - ArgPool(As), ArgPool(BsT), + ArgPool(As), ArgPool(Bs), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16) @@ -454,21 +452,21 @@ def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasu bench = BenchMM(cuda_graph_params, label, sub_label, "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", torch._scaled_mm, - ArgPool(As), ArgPool(BsT), + ArgPool(As), ArgPool(Bs), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16, use_fast_accum=True) elif kernel_type == 'cutlass_scaled_mm': bench = BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_mm_default", + "cutlass_fp8_fp8_bf16_scaled_mm", ops.cutlass_scaled_mm, - ArgPool(As), ArgPool(BsT), scale_a, scale_b, + ArgPool(As), ArgPool(Bs), scale_a, scale_b, torch.bfloat16) - elif kernel_type == 'cutlass_sparse_mm': + elif kernel_type == 'cutlass_scaled_sparse_mm': bench = BenchMM(cuda_graph_params, label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm_default", + "cutlass_fp8_fp8_bf16_scaled_sparse_mm", ops.cutlass_scaled_sparse_mm, ArgPool(BComps), ArgPool(Es), ArgPool(AsT), scale_b, scale_a, torch.bfloat16) @@ -525,11 +523,11 @@ def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int], # Prepare configs for all kernels standard_kernels = [ - {'kernel_type': 'pytorch_mm'}, - {'kernel_type': 'pytorch_scaled_mm'}, - {'kernel_type': 'pytorch_scaled_mm_fast'}, + # {'kernel_type': 'pytorch_mm'}, + # {'kernel_type': 'pytorch_scaled_mm'}, + # {'kernel_type': 'pytorch_scaled_mm_fast'}, {'kernel_type': 'cutlass_scaled_mm'}, - {'kernel_type': 'cutlass_sparse_mm'} + {'kernel_type': 'cutlass_scaled_sparse_mm'} ] # Create configs for standard kernels diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/utils.py b/benchmarks/cutlass_benchmarks/sparse_mm/utils.py index 0c7bde70412c7..0bc0816b1af62 100644 --- a/benchmarks/cutlass_benchmarks/sparse_mm/utils.py +++ b/benchmarks/cutlass_benchmarks/sparse_mm/utils.py @@ -49,14 +49,14 @@ def prune_to_2_4(tensor): def 
make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
                              k: int) -> Tuple[torch.Tensor, torch.Tensor]:
     a = torch.randn((m, k), device='cuda') * 5
-    b = torch.randn((n, k), device='cuda') * 5
+    b = torch.randn((n, k), device='cuda').t() * 5
 
     # # Initialize a to all ones
     # a = torch.ones((m, k), device='cuda')
 
     # # Initialize b to all ones
     # b = torch.ones((n, k), device='cuda')
 
-    b = prune_to_2_4(b)
+    b = prune_to_2_4(b.t()).t()
 
     if dtype == torch.int8:
         a, b = to_int8(a), to_int8(b)
@@ -69,7 +69,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
     else:
         raise ValueError("unsupported dtype")
 
-    b_compressed, e = ops.cutlass_compress_entry(b)
+    b_compressed, e = ops.cutlass_compress_entry(b.t())
 
     # Compressed B, Metadata, Original A, B
     return b_compressed, e, a, b
diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py b/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py
index 2999244bf9b95..77f15891d84b2 100644
--- a/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py
+++ b/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py
@@ -66,7 +66,7 @@
     # Ns : [2560, 4096, 5120, 6144, 8192, 12288, 14336, 15360,
     #       22016, 27648, 28672]
     "llama-representative-set": [
-        # ([4096, 4096], None),    # small K, small N
+        ([4096, 4096], None),    # small K, small N
         ([4096, 8192], None),    # small K, medium N
         ([4096, 22016], None),   # small K, large N
         ([14336, 4096], None),   # large K, small N
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index 31f8392412d83..0b49ac1719a4f 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -52,20 +52,85 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
       typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
   using Cutlass3xGemmM512 =
       typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
-
-  uint32_t const n = b.size(1);
-  uint32_t const mp2 =
+
+  using Cutlass3xGemm1 =
+      typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm2 =
+      typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm3 =
+      typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm4 =
+      typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm5 =
+      typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm6 =
+      typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm7 =
+      typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemm8 =
+      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
+
+  uint32_t const n = b.size(1);  // Batch size
+  uint32_t const m = a.size(0);
+  uint32_t const np2 =
       std::max(static_cast<uint32_t>(64), next_pow_2(n));  // next power of 2
-  if (mp2 <= 64) {
+  if (np2 <= 64) {
+    if (m == 28672) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+    else if (m == 4096 || m == 6144) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (np2 <= 128) {
+    if (m == 4096) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+    else if (m == 28672) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+    else if (m == 6144) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (np2 <= 256) {
+    if (m == 4096) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+    else if (m == 28672) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+    else if (m == 6144) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else {
+    if (m == 6144 || m == 28672) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+    else if (m == 4096) {
+      return cutlass_sparse_gemm_caller(
+          out, a, e, b, std::forward<EpilogueArgs>(args)...);
+    }
+  }
+
+  // Otherwise the default heuristic
+  if (np2 <= 64) {
     // n in [1, 64]
     return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
         out, a, e, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 128) {
+  } else if (np2 <= 128) {
     // n in (64, 128]
     return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
         out, a, e, b, std::forward<EpilogueArgs>(args)...);
-  } else if (mp2 <= 256) {
+  } else if (np2 <= 256) {
     // n in (128, 256]
     return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
         out, a, e, b, std::forward<EpilogueArgs>(args)...);
@@ -181,8 +246,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
                                              torch::Tensor const& e,
                                              torch::Tensor const& b,
                                              EpilogueArgs&&... epilogue_args) {
+  TORCH_CHECK(e.dtype() == torch::kUInt8);
   if (a.dtype() == torch::kInt8) {
-    TORCH_CHECK(e.dtype() == torch::kUInt8);
     TORCH_CHECK(b.dtype() == torch::kInt8);
 
     if (out.dtype() == torch::kBFloat16) {
@@ -195,7 +260,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
         out, a, e, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else if (a.dtype() == torch::kFloat8_e4m3fn) {
-    TORCH_CHECK(e.dtype() == torch::kUInt8);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
 
     if (out.dtype() == torch::kBFloat16) {
@@ -210,7 +274,6 @@
     }
   } else if (a.dtype() == torch::kFloat16) {
-    TORCH_CHECK(e.dtype() == torch::kUInt8);
     TORCH_CHECK(b.dtype() == torch::kFloat16);
 
     if (out.dtype() == torch::kBFloat16) {
@@ -226,7 +289,6 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co
   } else {  // a.dtype() == torch::kBFloat16
     TORCH_CHECK(a.dtype() == torch::kBFloat16);
-    TORCH_CHECK(e.dtype() == torch::kUInt8);
     TORCH_CHECK(b.dtype() == torch::kBFloat16);
 
     if (out.dtype() == torch::kBFloat16) {
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
index 2f303677c0a8d..ad103e9151ca3 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -512,6 +512,120 @@ struct sm90_bf16_config_default {
       KernelSchedule, EpilogueSchedule, float>;
 };
 
+//////////////////////// Cherry-Picking Kernels ////////////////////////
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_1 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64,_64,_256>;
+  using ClusterShape = Shape<_8,_1,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_2 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128,_64,_256>;
+  using ClusterShape = Shape<_8,_1,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_3 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64,_64,_256>;
+  using ClusterShape = Shape<_1,_2,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_4 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using TileShape = Shape<_64,_128,_256>;
+  using ClusterShape = Shape<_8,_1,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_5 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128,_128,_256>;
+  using ClusterShape = Shape<_8,_1,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_6 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64,_128,_256>;
+  using ClusterShape = Shape<_1,_2,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_7 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128,_128,_256>;
+  using ClusterShape = Shape<_1,_1,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_8 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128,_256,_128>;
+  using ClusterShape = Shape<_8,_1,_1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule, float>;
+};
+////////////////////////////////////////////////////////////////////////
+
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
 struct sm90_fp8_config_default {
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
index aaf964fe3f4c0..e5a80e75a6b0b 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
@@ -52,7 +52,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
   // Check for strides and alignment
   TORCH_CHECK(a.stride(1) == 1);                      // Row-major
-  TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1);  // Column-major
+  // TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1);  // Column-major
   // TORCH_CHECK(c.stride(0) % 16 == 0);                 // 16 Byte Alignment
   TORCH_CHECK(b.stride(1) % 16 == 0);                 // 16 Byte Alignment
   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index ff6fa7583789d..91a5cbe7321f1 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -557,7 +557,7 @@ def cutlass_scaled_sparse_mm(a: torch.Tensor,
     torch.ops._C.cutlass_scaled_sparse_mm(out, a, e, b, scale_a, scale_b,
                                           bias)
 
-    return out
+    return out.t()
 
 
 # aqlm

From 7c61ab014593c2aa76ad21e16dd8765c7b375b06 Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com" <rshaw@neuralmagic.com>
Date: Tue, 3 Dec 2024 20:47:07 +0000
Subject: [PATCH 2/2] updated with latest kernel

---
 .../compressed_tensors/schemes/compressed_tensors_24.py | 1 -
 1 file changed, 1 deletion(-)

diff --git
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index e3df8317bdb99..8607464bd9dc0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -170,7 +170,6 @@ def apply_weights(self, bias=bias ) - out = out.t() assert out.is_contiguous() return out
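
For reference, a minimal sketch of the call convention the patched bench_v1.py exercises: the sparse kernel takes the compressed weight and its metadata first, then the transposed activations with the scales swapped, and the Python wrapper now returns the re-transposed (m, n) output. The import path, the shapes, and the crude 2:4 mask below are illustrative assumptions; the benchmarks themselves build their operands with prune_to_2_4 and make_rand_sparse_tensors from utils.py.

    import torch
    from vllm import _custom_ops as ops  # assumed import path for the "ops" module used in the benchmarks

    m, k, n = 16, 4096, 4096  # illustrative token count and Llama-style hidden sizes
    a = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fn)  # dense activations, row-major
    w = torch.randn((n, k), device="cuda")
    w[:, 2::4] = 0
    w[:, 3::4] = 0  # crude 2:4 pattern along k so compression succeeds
    b = w.t().to(torch.float8_e4m3fn)  # weight as (k, n), column-major, as in make_rand_sparse_tensors

    b_compressed, e = ops.cutlass_compress_entry(b.t())  # compressed values + sparsity metadata

    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    # Compressed weight first, activations transposed, scales swapped; the
    # wrapper's trailing .t() (see the _custom_ops.py hunk) yields an (m, n) output.
    out = ops.cutlass_scaled_sparse_mm(b_compressed, e, a.t(), scale_b, scale_a,
                                       torch.bfloat16)
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
    assert torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2)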