diff --git a/CMakeLists.txt b/CMakeLists.txt
index 943424bc4edfa..0ebe0bcf4e8aa 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -202,18 +202,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
-  set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
 
   FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        GIT_TAG v3.5.1
+        GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
         GIT_PROGRESS TRUE
 
         # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
         # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
         # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-        GIT_SHALLOW TRUE
+        # GIT_SHALLOW FALSE
   )
   FetchContent_MakeAvailable(cutlass)
 
@@ -225,7 +225,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/quantization/gguf/gguf_kernel.cu"
       "csrc/custom_all_reduce.cu"
       "csrc/permute_cols.cu"
-      "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
+      "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+      "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
+      "csrc/sparse/cutlass/sparse_compressor.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -255,11 +257,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()
 
   #
-  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
+  # The cutlass_scaled_mm, cutlass_scaled_sparse_mm, and cutlass_compressor kernels
+  # for Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+      "csrc/sparse/cutlass/sparse_compressor.cu"
+      "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -268,12 +273,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
   else()
     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
+      message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                     "later if you intend on running FP8 quantized models on "
+                     "later if you intend on running FP8 quantized or sparse models on "
                      "Hopper.")
     else()
-      message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
+      message(STATUS "Not building cutlass_c3x kernels as no compatible archs found "
                      "in CUDA target architectures")
     endif()
 
@@ -398,6 +403,9 @@ define_gpu_extension_target(
 # Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) +# include(nm_cutlass_c.cmake) +# build_nm_cutlass_c() + # # _moe_C extension # diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py new file mode 100644 index 0000000000000..22616c2359b74 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py @@ -0,0 +1,311 @@ +## Cutlass benchmark V1 + +from typing import Callable, Iterable + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors + +import vllm._custom_ops as ops + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + a_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + torch.float16)) + + # cutlass with bias: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + + # Create tensors + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) + aT = a.t() + bT = b.t() + bf16_a = a.to(dtype=torch.bfloat16) + bf16_bT = bT.to(dtype=torch.bfloat16) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # pytorch impl w. 
bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + bT.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + bT, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + bT, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + bT, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + bT, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, + torch.bfloat16)) + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16)) + + return timers + + +def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float16 + a_compressed, e, a, b = make_rand_sparse_tensors(torch.float16, m, n, k) + + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # # pytorch impl w. 
bf16 + # timers.append( + # bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + # torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + # b.to(dtype=torch.bfloat16, device="cuda"))) + + # # pytorch impl: bf16 output + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp16_fp16_bf16_scaled_mm", + # torch._scaled_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.bfloat16)) + + # # pytorch impl: fp16 output + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp16_fp16_fp16_scaled_mm", + # torch._scaled_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.float16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + + return timers + + +def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.bfloat16 + a_compressed, e, a, b = make_rand_sparse_tensors(torch.bfloat16, m, n, k) + + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + timers = [] + + # # pytorch impl w. 
bf16 + # timers.append( + # bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + # torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + # b.to(dtype=torch.bfloat16, device="cuda"))) + + # # pytorch impl: bf16 output + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp16_fp16_bf16_scaled_mm", + # torch._scaled_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.bfloat16)) + + # # pytorch impl: fp16 output + # timers.append( + # bench_fn(label, + # sub_label, + # "pytorch_fp16_fp16_fp16_scaled_mm", + # torch._scaled_mm, + # a, + # b, + # scale_a=scale_a, + # scale_b=scale_b, + # out_dtype=torch.float16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_bf16_bf16_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_bf16_bf16_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + + return timers + + +def bench_v1(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + # if dtype == torch.int8: + # return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + # if dtype == torch.float16: + # return bench_fp16(dtype, m, k, n, label, sub_label) + # if dtype == torch.bfloat16: + # return bench_bf16(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py new file mode 100644 index 0000000000000..f9b4871044526 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v2.py @@ -0,0 +1,556 @@ +import dataclasses +import random +from typing import Any, Callable, Iterable, Optional, Tuple, Dict, List + +import multiprocessing as mp +from multiprocessing import Process, Queue +from queue import Empty + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_n_rand_sparse_tensors + +import vllm._custom_ops as ops +import traceback + +import json +import os +import hashlib +from datetime import datetime +from pathlib import Path + + +@dataclasses.dataclass +class CudaGraphBenchParams: + num_ops_in_cuda_graph: int + + +@dataclasses.dataclass +class ArgPool: + ''' + When some argument of the benchmarking function is annotated with this type, + the benchmarking class (BenchMM) will collapse the argument to a pick a + single value from the given list of values, during function invocation. + + For every invocation during a benchmarking run, it will choose a + different value from the list. 
+ ''' + values: Iterable[Any] + + +class BenchMM: + + class ArgsIterator: + + def __init__(self, args_list, kwargs_list): + assert len(args_list) == len(kwargs_list) + self.args_list = args_list + self.kwargs_list = kwargs_list + self.n = len(self.args_list) + self.idx = 0 + + def __next__(self): + while True: + yield (self.args_list[self.idx], self.kwargs_list[self.idx]) + self.idx += 1 + self.idx = self.idx % self.n + + def reset(self): + self.idx = 0 + + @property + def n_args(self): + return self.n + + def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, sub_label: str, description: str, fn: Callable, + *args, **kwargs): + + self.cuda_graph_params = cuda_graph_params + self.use_cuda_graph = self.cuda_graph_params is not None + self.label = label + self.sub_label = sub_label + self.description = description + self.fn = fn + + # Process args + self._args = args + self._kwargs = kwargs + self.args_list, self.kwargs_list = self.collapse_argpool( + *args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, + self.kwargs_list) + + # Cudagraph runner + self.g = None + if self.use_cuda_graph: + self.g = self.get_cuda_graph_runner() + + # benchmark run params + self.min_run_time = 1 + + def collapse_argpool(self, *args, **kwargs): + kwargs = kwargs if kwargs is not None else {} + assert kwargs is None or all([ + not isinstance(v, ArgPool) for k, v in kwargs.items() + ]), 'ArgPools in kwargs are not supported yet' + + arg_pool_indices = [ + i for i, x in enumerate(args) if isinstance(x, ArgPool) + ] + if len(arg_pool_indices) == 0: + return [args], [kwargs] + + # make sure all the Arg pools have the same number of choices + arg_pool_size = len(args[arg_pool_indices[0]].values) + assert all( + [len(args[i].values) == arg_pool_size for i in arg_pool_indices]) + + # create copies of the args + args_list = [] + kwargs_list = [] + for _ in range(arg_pool_size): + args_list.append(args) + kwargs_list.append(kwargs.copy()) + + # collapse the arg pools by simply choosing the ith value + for i in range(arg_pool_size): + assert isinstance(args_list[i], tuple) + # get as list + args_i = list(args_list[i]) + # collapse - make replacements + for arg_pool_idx in arg_pool_indices: + val_from_pool = args_i[arg_pool_idx].values[i] + args_i[arg_pool_idx] = val_from_pool + # store back as tuple + args_list[i] = tuple(args_i) + + return args_list, kwargs_list + + def get_cuda_graph_runner(self): + assert self.use_cuda_graph + assert self.args_iterator is not None + + num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph + + # warmup + args_it = self.args_iterator.__next__() + for _ in range(5): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + return g + + def run_cudagraph(self) -> TMeasurement: + assert self.use_cuda_graph + globals = {'g': self.g} + + return TBenchmark.Timer( + stmt="g.replay()", + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run_eager(self) -> TMeasurement: + setup = None + stmt = None + globals = None + + has_arg_pool = self.args_iterator.n_args > 1 + if has_arg_pool: + setup = ''' + args_iterator.reset() + args_it = 
args_iterator.__next__() + ''' + stmt = ''' + args, kwargs = next(args_it) + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + else: + # no arg pool. Just use the args and kwargs directly + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + args, kwargs = next(args_it) + + setup = "" + stmt = ''' + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + + return TBenchmark.Timer( + stmt=stmt, + setup=setup, + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run(self) -> TMeasurement: + timer = None + if self.use_cuda_graph: # noqa SIM108 + timer = self.run_cudagraph() + else: + timer = self.run_eager() + #assert timer.meets_confidence() + #assert not timer.has_warnings, f"Warnings {timer._warnings}" + if not timer.meets_confidence() or timer.has_warnings: + print("Doesn't meet confidence - re-running bench ...") + return self.run() + return timer + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + print(f"exc type {exc_type}") + print(f"exc value {exc_value}") + print(f"exc traceback {traceback}") + + +def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue): + """ + Run a single kernel benchmark in an isolated process. + Puts (success, result, config) tuple in the queue. + """ + try: + torch.cuda.set_device(gpu_id) + + # Initialize CUDA tensors + m, k, n = kernel_config['m'], kernel_config['k'], kernel_config['n'] + dtype = kernel_config['dtype'] + + # Create tensors + BComps, Es, As, Bs = make_n_rand_sparse_tensors( + kernel_config.get('arg_pool_size', 1), + dtype, m, n, k + ) + AsT = [x.t() for x in As] + BsT = [x.t() for x in Bs] + bf16_As = [x.to(dtype=torch.bfloat16) for x in As] + bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT] + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + # Because the transposed output will be computed + out = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda") + + # Setup benchmark params + cuda_graph_params = None + if cgops := kernel_config.get('cuda_graph_ops'): + cuda_graph_params = CudaGraphBenchParams(cgops) + + label = kernel_config['label'] + sub_label = kernel_config['sub_label'] + + # Initialize benchmark based on kernel type + bench = None + kernel_type = kernel_config['kernel_type'] + + if kernel_type == 'pytorch_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + ArgPool(bf16_As), ArgPool(bf16_BsT)) + + elif kernel_type == 'pytorch_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + ArgPool(As), ArgPool(BsT), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16) + + elif kernel_type == 'pytorch_scaled_mm_fast': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + ArgPool(As), ArgPool(BsT), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True) + + elif kernel_type == 'cutlass_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm_default", + ops.cutlass_scaled_mm, + ArgPool(As), ArgPool(BsT), scale_a, scale_b, + torch.bfloat16) + + elif kernel_type == 'cutlass_sparse_mm': + 
bench = BenchMM(cuda_graph_params, label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_default", + ops.cutlass_scaled_sparse_mm, + ArgPool(BComps), ArgPool(Es), ArgPool(AsT), + scale_b, scale_a, torch.bfloat16) + + + # Run the benchmark + result = bench.run() + queue.put((True, result, kernel_config)) + + except Exception as e: + print(f"Error in benchmark process: {str(e)}") + print(traceback.format_exc()) + queue.put((False, None, kernel_config)) + finally: + # Explicit cleanup + torch.cuda.empty_cache() + +def benchmark_gpu_worker(gpu_id: int, task_queue: Queue, result_queue: Queue): + """Worker process that spawns individual benchmark processes for each kernel.""" + try: + while True: + try: + kernel_config = task_queue.get_nowait() + if kernel_config is None: # Poison pill + break + + # Create a new process queue for this specific benchmark + process_queue = Queue() + + # Create and start a new process for this kernel benchmark + p = Process(target=run_single_benchmark_process, + args=(kernel_config, gpu_id, process_queue)) + p.start() + + # Wait for result with timeout (5 minutes for benchmarking) + try: + success, result, config = process_queue.get(timeout=300) + result_queue.put((success, result, config)) + except Empty: + print(f"Kernel {kernel_config.get('kernel_type')} benchmark timed out") + result_queue.put((False, None, kernel_config)) + + # Cleanup + p.join(timeout=1) # Give it 1 second to join + if p.is_alive(): + p.terminate() + p.join() + + except Empty: + break + except Exception as e: + print(f"Error in GPU {gpu_id} worker: {str(e)}") + print(traceback.format_exc()) + if 'kernel_config' in locals(): + result_queue.put((False, None, kernel_config)) + + finally: + print(f"GPU {gpu_id} worker finished") + +def run_kernels_on_gpus(configs: List[Dict]) -> List[Tuple[bool, Optional[TMeasurement], Dict]]: + MULTI_GPU_MULTI_PROCESS = False # Set to False for single GPU testing + if MULTI_GPU_MULTI_PROCESS: + gpus_list = [0] + task_queue = Queue() + result_queue = Queue() + + configs = configs[:10] + + # Fill task queue + for config in configs: + task_queue.put(config) + for _ in gpus_list: # Add poison pills + task_queue.put(None) + + # Start GPU workers + workers = [] + for gpu_id in gpus_list: + p = Process(target=benchmark_gpu_worker, args=(gpu_id, task_queue, result_queue)) + p.start() + workers.append(p) + + # Collect results + results = [] + completed = 0 + total_tasks = len(configs) + + while completed < total_tasks: + success, result, config = result_queue.get() + results.append((success, result, config)) + completed += 1 + + # Print progress + status = "Success" if success else "Failed" + print(f"{status}: {config['kernel_type']}") + + # Cleanup workers + for worker in workers: + worker.join(timeout=1) + if worker.is_alive(): + worker.terminate() + worker.join() + + return results + else: + """Run kernel benchmarks in a single process.""" + results = [] + gpu_id = 0 # Using the same GPU as before + torch.cuda.set_device(gpu_id) + # configs = configs[:10] # Keep the original slice + + for config in configs: + try: + # Initialize CUDA tensors + m, k, n = config['m'], config['k'], config['n'] + dtype = config['dtype'] + + # Create tensors + BComps, Es, As, Bs = make_n_rand_sparse_tensors( + config.get('arg_pool_size', 1), + dtype, m, n, k + ) + AsT = [x.t() for x in As] + BsT = [x.t() for x in Bs] + bf16_As = [x.to(dtype=torch.bfloat16) for x in As] + bf16_BsT = [x.to(dtype=torch.bfloat16) for x in BsT] + scale_a = torch.tensor(1.0, device="cuda", 
dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + out = torch.zeros((n, m), dtype=torch.bfloat16, device="cuda") + + # Setup benchmark params + cuda_graph_params = None + if cgops := config.get('cuda_graph_ops'): + cuda_graph_params = CudaGraphBenchParams(cgops) + + label = config['label'] + sub_label = config['sub_label'] + + # Initialize benchmark based on kernel type + bench = None + kernel_type = config['kernel_type'] + + if kernel_type == 'pytorch_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + ArgPool(bf16_As), ArgPool(bf16_BsT)) + + elif kernel_type == 'pytorch_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + ArgPool(As), ArgPool(BsT), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16) + + elif kernel_type == 'pytorch_scaled_mm_fast': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + ArgPool(As), ArgPool(BsT), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True) + + elif kernel_type == 'cutlass_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm_default", + ops.cutlass_scaled_mm, + ArgPool(As), ArgPool(BsT), scale_a, scale_b, + torch.bfloat16) + + elif kernel_type == 'cutlass_sparse_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_default", + ops.cutlass_scaled_sparse_mm, + ArgPool(BComps), ArgPool(Es), ArgPool(AsT), + scale_b, scale_a, torch.bfloat16) + + # Run the benchmark + result = bench.run() + + # Print progress + print(f"Success: {kernel_type}") + + results.append((True, result, config)) + + # Cleanup + torch.cuda.empty_cache() + + except Exception as e: + print(f"Error in benchmark: {str(e)}") + print(traceback.format_exc()) + results.append((False, None, config)) + torch.cuda.empty_cache() + + return results + + +def get_cache_path() -> str: + """Get the path to the cache file for the given configuration hash.""" + return f'{Path(os.path.dirname(os.path.realpath(__file__)))}/stable_kernels.json' + + +def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int], + with_arg_pool: Optional[int], m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + + # Check if context is not set + try: + mp.set_start_method('spawn', force=True) + except RuntimeError: + pass + + timers = [] + gpus_list = [5] # Using the same GPU list as original code + + # Base configuration for all kernels + base_config = { + 'm': m, + 'k': k, + 'n': n, + 'dtype': dtype, + 'cuda_graph_ops': with_cuda_graph, + 'arg_pool_size': with_arg_pool if with_arg_pool else 1, + 'label': label, + 'sub_label': sub_label + } + + # Prepare configs for all kernels + standard_kernels = [ + {'kernel_type': 'pytorch_mm'}, + {'kernel_type': 'pytorch_scaled_mm'}, + {'kernel_type': 'pytorch_scaled_mm_fast'}, + {'kernel_type': 'cutlass_scaled_mm'}, + {'kernel_type': 'cutlass_sparse_mm'} + ] + + # Create configs for standard kernels + all_configs = [{**base_config, **kernel} for kernel in standard_kernels] + + # Run all kernels distributed across GPUs + print(f"Running {len(all_configs)} benchmarks across {len(gpus_list)} GPUs...") + results = run_kernels_on_gpus(all_configs) + + # Process results + for success, result, _ in results: + if success and result is not None: + timers.append(result) + + 
return timers + + +def bench_v2(dtype: torch.dtype, with_cuda_graph: Optional[int], + with_arg_pool: Optional[int], m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, with_cuda_graph, with_arg_pool, m, k, n, label, + sub_label) + raise ValueError("unsupported type") diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/mm_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_mm/mm_benchmarks.py new file mode 100644 index 0000000000000..82567a57b303a --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_mm/mm_benchmarks.py @@ -0,0 +1,215 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from bench_v1 import bench_v1 +from bench_v2 import bench_v2 +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + dtype = args.dtype + + use_bench_v2 = args.with_cuda_graph or args.with_arg_pool + for m, k, n in MKNs: + if use_bench_v2: + label = f"scaled-sparse-{dtype}-gemm" + label = f"{label}-cugraph_{args.with_cuda_graph}" \ + if args.with_cuda_graph else label + label = f"{label}-argpool_{args.with_arg_pool}" \ + if args.with_arg_pool else label + timers = bench_v2(args.dtype, args.with_cuda_graph, + args.with_arg_pool, m, k, n, label, + f"MKN=({m}x{k}x{n})") + else: + timers = bench_v1(args.dtype, m, k, n, f"scaled-sparse-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + if tp_split_dim is not None: + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + 
return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + if dt == "fp16": + return torch.float16 + if dt == "bf16": + return torch.bfloat16 + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_mm/mm_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_mm/mm_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_mm/mm_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8', 'fp16', 'bf16']") + parser.add_argument( + '--with-cuda-graph', + type=int, + default=None, + help="Number of ops/matmuls in a cudagraph execution. When set" + "cuda-graphs is enabled") + parser.add_argument( + '--with-arg-pool', + type=int, + default=None, + help="Number of A and B tensors to use as arg-pool. 
When not set," + "it defaults to 1") + + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/stable_kernels.json b/benchmarks/cutlass_benchmarks/sparse_mm/stable_kernels.json new file mode 100644 index 0000000000000..2cb53b60807a1 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_mm/stable_kernels.json @@ -0,0 +1 @@ +{"date": "2024-11-09T05:36:00.932166", "stable_kernels": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 
386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 537, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 
1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287, 1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295]} \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/utils.py b/benchmarks/cutlass_benchmarks/sparse_mm/utils.py new file mode 100644 index 0000000000000..2d753b254a0ab --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_mm/utils.py @@ -0,0 +1,87 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + # # Initialize a to all ones + # a = torch.ones((m, k), device='cuda') + # # Initialize b to all ones + # b = torch.ones((n, k), device='cuda') + + b = prune_to_2_4(b) + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_compress_entry(b) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, 
a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py b/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py new file mode 100644 index 0000000000000..2999244bf9b95 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_mm/weight_shapes.py @@ -0,0 +1,75 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "meta-llama/Llama-2-70b-tp4-hf": [([8192, 2560], None), ([2048, + 8192], None), + ([8192, 14336], None), + ([7168, 8192], None)], + # The shape space is very big when benchmarking a large set of kernels. + # For example: Let, + # - #kernels to benchmark be 1700 + # - #models to benchmark be 4 (each model has 4 shapes) + # - #batch sizes be 6 (16, 32, 64, 128, 256, 512) + # For 1 kernel, 1 shape and 1 batch-size, H100 takes 1 second (approx.) + # to run, then the benchmark suite would take, + # 1700 * (4 * 4) * 6 = 163200 seconds => 46 hrs. + # Below, we exploit some observation on the benchmark shapes to create a + # representative set. + # + # From previous benchmarking runs, we observe that perf if stratified as, + # N - small, medium, large and K - small and large. We also observe that + # in the model shapes, when K is small, we have small, medium and large Ns. + # when K is large, we only have small Ns. 
+ # + # models : ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-3-8b', + # 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-70b-tp4-hf'] + # Ks : [2048, 4096, 5120, 7168, 8192, 11008, 13824, 14336] + # Ns : [2560, 4096, 5120, 6144, 8192, 12288, 14336, 15360, + # 22016, 27648, 28672] + "llama-representative-set": [ + # ([4096, 4096], None), # small K, small N + ([4096, 8192], None), # small K, medium N + ([4096, 22016], None), # small K, large N + ([14336, 4096], None), # large K, small N + ([8192, 14336], None), # medium K, large N (from llama-2-70b-tp4-hf + ], +} diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..abcde3b016a7b 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -386,4 +386,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 25ec9d6028627..77f15891d84b2 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -40,4 +40,36 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], + "meta-llama/Llama-2-70b-tp4-hf": [([8192, 2560], None), ([2048, + 8192], None), + ([8192, 14336], None), + ([7168, 8192], None)], + # The shape space is very big when benchmarking a large set of kernels. + # For example: Let, + # - #kernels to benchmark be 1700 + # - #models to benchmark be 4 (each model has 4 shapes) + # - #batch sizes be 6 (16, 32, 64, 128, 256, 512) + # For 1 kernel, 1 shape and 1 batch-size, H100 takes 1 second (approx.) + # to run, then the benchmark suite would take, + # 1700 * (4 * 4) * 6 = 163200 seconds => 46 hrs. + # Below, we exploit some observation on the benchmark shapes to create a + # representative set. + # + # From previous benchmarking runs, we observe that perf if stratified as, + # N - small, medium, large and K - small and large. We also observe that + # in the model shapes, when K is small, we have small, medium and large Ns. + # when K is large, we only have small Ns. 
+ # + # models : ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-3-8b', + # 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-70b-tp4-hf'] + # Ks : [2048, 4096, 5120, 7168, 8192, 11008, 13824, 14336] + # Ns : [2560, 4096, 5120, 6144, 8192, 12288, 14336, 15360, + # 22016, 27648, 28672] + "llama-representative-set": [ + ([4096, 4096], None), # small K, small N + ([4096, 8192], None), # small K, medium N + ([4096, 22016], None), # small K, large N + ([14336, 4096], None), # large K, small N + ([8192, 14336], None), # medium K, large N (from llama-2-70b-tp4-hf + ], } diff --git a/csrc/ops.h b/csrc/ops.h index c50eb39a3dacc..3c9b814a456e8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -142,6 +142,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); + +bool cutlass_scaled_sparse_mm_supports_fp8(int64_t cuda_device_capability); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +bool cutlass_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, + torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 292c9e4b34e1c..84e1f367c8722 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,589 +1,7 @@ -// clang-format will break include orders -// clang-format off -#include - -#if defined CUDA_VERSION && CUDA_VERSION >= 12000 - +#include #include - -#include - -#include -#include -#include - #include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "broadcast_load_epilogue_c3x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; - -/* - This file defines quantized GEMM operations using the CUTLASS 3.x API, for - NVIDIA GPUs with sm90a (Hopper) or later. - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm90EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. -*/ - -namespace { - -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); - #endif - } -}; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] 
classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - template - using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; - - // Don't want to support nullptr by default - template - using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // Don't want to support nullptr by default - template - using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - static_assert(!std::is_same_v> && - !std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - static_assert(std::is_same_v> || - std::is_same_v>); - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch.scaled_mm_. - - A and B may be both either int8 or fp8_e4m3. A can be - quantized per-tensor or per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. 
-*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. 
- */ -template -struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. 
- */ -template -struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // 
clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; - - struct GemmKernel : public KernelType {}; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. 
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; - - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_default { - // For M > 128 and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M128 { - // For M in (64, 128] and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M64 { - // For M in (32, 64] and any N - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NBig { - // For M in [1, 32] and N >= 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NSmall { - // For M in [1, 32] and N < 8192 - 
static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -} // namespace +#include "scaled_mm_c3x.cuh" template typename Epilogue, @@ -748,4 +166,75 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, } } -#endif +// hyper-parameter sweep kernels + +void cutlass_scaled_mm_sm90_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias) { + assert(!bias); + + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using AccType = float; + + if (out.dtype() == torch::kBFloat16) { + using Cutlass3xGemm = + cutlass_3x_gemm; + + return cutlass_gemm_caller(out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + using Cutlass3xGemm = + cutlass_3x_gemm; + + return cutlass_gemm_caller(out, a, b, a_scales, b_scales); + } +} + +void cutlass_simple_gemm_sm90_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b) { + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using AccType = float; + + if (out.dtype() == torch::kBFloat16) { + using Cutlass3xGemm = + cutlass_3x_simple_gemm; + + return cutlass_simple_gemm_caller(out, a, b); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + using Cutlass3xGemm = + cutlass_3x_simple_gemm; + + return cutlass_simple_gemm_caller(out, a, b); + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh new file mode 100644 index 0000000000000..9b1dd748bfbbe --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh @@ -0,0 +1,777 @@ +#pragma once + +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "broadcast_load_epilogue_c3x.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This file defines quantized GEMM operations using the CUTLASS 3.x API, for + NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogue functions can be defined to post-process the output before it is + written to GPU memory. + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. 
+*/ + +// A wrapper for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { + #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); + #endif + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. 
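// ---------------------------------------------------------------------------
// Illustrative reference only (not part of this patch): the scaled GEMM that
// the ScaledEpilogue below fuses into the kernel, written out with plain
// libtorch ops. The function name and signature are assumptions made purely
// for illustration; only the formula D = (a_scales * A) (b_scales * B) with
// numpy-style broadcasting comes from the comment above.
#include <torch/torch.h>

torch::Tensor ref_scaled_mm(torch::Tensor const& a,         // [m, k] int8/fp8
                            torch::Tensor const& b,         // [k, n] int8/fp8
                            torch::Tensor const& a_scales,  // [m, 1] or scalar
                            torch::Tensor const& b_scales,  // [1, n] or scalar
                            torch::ScalarType out_dtype) {
  // Per-row scales broadcast over A, per-column scales broadcast over B.
  auto d = (a_scales * a.to(torch::kFloat32))
               .matmul(b_scales * b.to(torch::kFloat32));
  return d.to(out_dtype);
}
// ---------------------------------------------------------------------------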
+*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
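// ---------------------------------------------------------------------------
// Illustrative sketch (assumption, not part of this patch): one way the
// azp_adj operand described above could be prepared on the host for the
// per-tensor case. Since J is the all-ones matrix, azp * J @ B reduces to a
// scaled column sum of B. The helper name is invented for illustration.
#include <torch/torch.h>

torch::Tensor make_azp_adj_per_tensor(int32_t azp, torch::Tensor const& b) {
  // azp * (J @ B): shape (1, n), int32, matching the AzpWithAdj row broadcast.
  return azp * b.to(torch::kInt32).sum(/*dim=*/0, /*keepdim=*/true);
}
// ---------------------------------------------------------------------------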
+ */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
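// ---------------------------------------------------------------------------
// Illustrative sketch (assumption, not part of this patch): the per-token
// correction that ScaledEpilogueBiasAzpToken fuses, written out with libtorch
// ops. The rank-1 outer product azp (m,1) x azp_adj (1,n) is exactly the
// O(m*n) term the epilogue avoids materializing. Names are invented for
// illustration; the formula follows the comment above.
#include <torch/torch.h>

torch::Tensor ref_azp_token_epilogue(
    torch::Tensor const& accum,     // [m, n] int32 accumulator
    torch::Tensor const& azp,       // [m, 1] int32, per-token zero point
    torch::Tensor const& azp_adj,   // [1, n] int32, J @ B
    torch::Tensor const& a_scales,  // [m, 1] or scalar, float32
    torch::Tensor const& b_scales,  // [1, n] or scalar, float32
    torch::Tensor const& bias) {    // [1, n]
  auto corrected = (accum - azp * azp_adj).to(torch::kFloat32);
  return a_scales * (b_scales * corrected) + bias;
}
// ---------------------------------------------------------------------------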
+ */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule, typename AccType, + typename TileSchedule = cutlass::gemm::PersistentScheduler, + GemmUniversalMode Mode_ = GemmUniversalMode::kGemm> +struct cutlass_3x_gemm { + static const GemmUniversalMode Mode = Mode_; + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, AccType, + AccType>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + 
EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + TileSchedule>>; + + struct GemmKernel : public KernelType {}; +}; + +template +inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, + epilogue_args}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode; +using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; +using RasterOrderOptions = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + +template +inline void cutlass_gemm_caller_streamk(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + ReductionMode reduction_mode, + DecompositionMode decomposition_mode, + EpilogueArgs&&... 
epilogue_params) { + + static_assert(std::is_same::value, "Must be streamk scheduler"); + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::TileSchedulerArguments tile_scheduler_args( + 1, + 1, + RasterOrderOptions::Heuristic, + decomposition_mode + ); + tile_scheduler_args.reduction_mode = reduction_mode; + + // Copied from examples... + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, + epilogue_args, hw_info, tile_scheduler_args}; + + // Launch the CUTLASS GEMM kernel. 
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_default { + // For M > 128 and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M128 { + // For M in (64, 128] and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M64 { + // For M in (32, 64] and any N + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NBig { + // For M in [1, 32] and N >= 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NSmall { + // For M in [1, 32] and N < 8192 + 
static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template +struct cutlass_3x_simple_gemm { + static const GemmUniversalMode Mode = Mode_; + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, AccType, + AccType>::type; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + TileSchedule>>; + + struct GemmKernel : public KernelType {}; +}; + +template +inline void cutlass_simple_gemm_caller(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, + epilogue_args}; + + // Launch the CUTLASS GEMM kernel. 
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +#endif diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor.cu new file mode 100644 index 0000000000000..660ee33044d9f --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor.cu @@ -0,0 +1,203 @@ +#include + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/dependent_false.hpp" + +#include "util/broadcast_load_epilogue_c3x.hpp" +#include "util/common.hpp" + +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "util/host_tensor.h" +#include "util/packed_stride.hpp" + +#include "util/helper.h" + +#include "sparse_scaled_mm_c3x.cuh" + +/// Make A structured sparse by replacing elements with 0 and compress it +template +bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a) +{ + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || + a.dtype() == torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || + a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1) + + int m = a.size(0); + int k = a.size(1); + + using ProblemShape = Shape; + using ElementA = ElementA_; + using LayoutTagA = cutlass::layout::RowMajor; + + // Layouts for reference (non-sparse) tensors + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + using Gemm = + typename std::conditional, + typename sm90_int8_config_default::Cutlass3xGemm, + typename std::conditional, + typename sm90_fp8_config_default::Cutlass3xGemm, + typename std::conditional, + typename sm90_fp16_config_default::Cutlass3xGemm, + typename sm90_bf16_config_default::Cutlass3xGemm + >::type + >::type + >::type; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + // Just a dummy value + int32_t n = 128; + + int64_t lda = a.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; + using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; + + using 
ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; + + LayoutA a_layout = SparseConfig::fill_layoutA(prob_shape); + LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape); + + // typename Gemm::GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; + + // Offline compressor kernel + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm90>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + auto [M, N, K, L] = prob_shape; + + StrideA stride_A; + StrideA stride_A_compressed; + StrideE stride_E; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(prob_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto a_ptr = static_cast(a.data_ptr()); + + // cutlass::DeviceAllocation block_A; + // cutlass::DeviceAllocation block_A_compressed; + // cutlass::DeviceAllocation block_E; + + auto a_compressed_ptr = static_cast(a_compressed.data_ptr()); + auto e_ptr = static_cast(e.data_ptr()); + + // block_A_compressed.reset(M * KC * L); + // block_E.reset(ME * KE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + + // // Random sparsification is performed on host + // std::vector block_A_host(m * k); + // cutlass::device_memory::copy_to_host(block_A_host.data(), a_ptr, m * k); + // compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), 2024); + // cutlass::device_memory::copy_to_device(a_ptr, block_A_host.data(), m * k); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + prob_shape, + { a_ptr, + stride_A, + a_compressed_ptr, + e_ptr }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +bool cutlass_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e, torch::Tensor const& a) +{ + if (a.dtype() == torch::kBFloat16) { + return sparsify_and_compress(a_compressed, e, a); + } else if (a.dtype() == torch::kFloat16) { + return sparsify_and_compress(a_compressed, e, a); + } else if (a.dtype() == torch::kFloat8_e4m3fn) { + return sparsify_and_compress(a_compressed, e, a); + } + else if (a.dtype() == torch::kInt8) { + return sparsify_and_compress(a_compressed, e, a); + } + return false; +} \ No newline at end of file diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu new file mode 100644 index 0000000000000..a62598587b1b1 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -0,0 +1,285 @@ +// 
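// ---------------------------------------------------------------------------
// Illustrative sketch (assumption, not part of this patch): the compressor
// above targets structured sparsity as used by Hopper sparse tensor cores,
// commonly the 2:4 pattern, i.e. at most 2 non-zeros in every group of 4
// consecutive elements along K. The 2:4 group size and the checker below are
// assumptions made only to make that constraint concrete.
#include <torch/torch.h>

bool is_2_4_structured(torch::Tensor const& a) {  // a: [m, k] with k % 4 == 0
  auto groups = a.reshape({a.size(0), a.size(1) / 4, 4});
  auto nnz_per_group = (groups != 0).sum(/*dim=*/-1);
  return nnz_per_group.le(2).all().item<bool>();
}
// ---------------------------------------------------------------------------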
clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "util/broadcast_load_epilogue_c3x.hpp" +#include "util/common.hpp" +// clang-format on + +#include "sparse_scaled_mm_c3x.cuh" + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(e.dtype() == torch::kUInt8); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm90_fp8_config_M256::Cutlass3xGemm; + using Cutlass3xGemmM512 = + typename sm90_fp8_config_M512::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } else if (mp2 <= 256) { + // m in (128, 256] + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } else { + // m in (256, inf) + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat16); + TORCH_CHECK(e.dtype() == torch::kUInt8); + TORCH_CHECK(b.dtype() == torch::kFloat16); + + using Cutlass3xGemmDefault = + typename sm90_fp16_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kBFloat16); + TORCH_CHECK(e.dtype() == torch::kUInt8); + TORCH_CHECK(b.dtype() == torch::kBFloat16); + + using Cutlass3xGemmDefault = + typename sm90_bf16_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& e, + torch::Tensor const& b, + EpilogueArgs&&... 
args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(e.dtype() == torch::kUInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_int8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } else { + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, e, b, std::forward(args)...); + } +} + +template
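// ---------------------------------------------------------------------------
// For reference: the M-bucketing in the dispatch functions above relies on a
// next_pow_2 helper, assumed here to be provided by the shared utility headers
// pulled in via common.hpp. A minimal equivalent sketch (not the patch's own
// definition):
#include <cstdint>

static inline uint32_t next_pow_2_sketch(uint32_t x) {
  // Smallest power of two >= x; clamps 0 and 1 to 1.
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}
// e.g. m = 100 -> 128, so after std::max(64, ...) the (64, 128] tile
// configuration is selected in the fp8 and int8 dispatchers above.
// ---------------------------------------------------------------------------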