Commit

Formatting
ilmarkov committed Nov 13, 2024
1 parent 31cf482 commit 68512d4
Showing 2 changed files with 74 additions and 69 deletions.
109 changes: 54 additions & 55 deletions benchmarks/cusparseLt_benchmarks/benchmark_24.py
@@ -1,6 +1,8 @@
import argparse
import copy
import itertools
import pickle
import time
from typing import Callable, Iterable, List, Tuple

import torch
@@ -13,8 +15,6 @@
is_semi_structured_supported, semi_structured_sparse_dense_gemm,
semi_structured_sparse_dense_gemm_scaled)
from vllm.utils import FlexibleArgumentParser
import time
import pickle

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [32, 64, 128, 256, 512]
@@ -54,21 +54,15 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,

timers = []
# pytorch float16
# timers.append(
# bench_fn(label, sub_label, "pytorch_fp16_fp16_matmul", torch.mm,
# a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

# # pytorch bf16
# timers.append(
# bench_fn(label, sub_label, "pytorch_bf16_bf16_matmul", torch.mm,
# a.to(dtype=torch.bfloat16, device="cuda"),
# b.to(dtype=torch.bfloat16, device="cuda")))
timers.append(
bench_fn(label, sub_label, "pytorch_fp16_fp16_matmul", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

# # cusparseLt fp16
# timers.append(
# bench_fn(label, sub_label, "cusparseLt_fp16_fp16_2_4",
# semi_structured_sparse_dense_gemm,
# compress_to_torch_sparse_semi_structured_mat(a), b))
# cusparseLt fp16
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp16_fp16_2_4",
semi_structured_sparse_dense_gemm,
compress_to_torch_sparse_semi_structured_mat(a), b))

# timers.append(
# bench_fn(label,
@@ -79,14 +73,17 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,
# b,
# cached=False))

# # cusparseLt bf16
# a, b = make_rand_tensors(torch.bfloat16, m, n, k)
# a_compressed = compress_to_torch_sparse_semi_structured_mat(
# a.to(dtype=torch.bfloat16))
# pytorch bf16
a, b = make_rand_tensors(torch.bfloat16, m, n, k)
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_matmul", torch.mm, a, b))

# timers.append(
# bench_fn(label, sub_label, "cusparseLt_bf16_bf16_2_4",
# semi_structured_sparse_dense_gemm, a_compressed, b))
# cusparseLt bf16
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)

timers.append(
bench_fn(label, sub_label, "cusparseLt_bf16_bf16_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))

# timers.append(
# bench_fn(label,
@@ -99,9 +96,9 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,

a, b = make_rand_tensors(torch.int8, m, n, k)
# # cutlass i8
# timers.append(
# bench_fn(label, sub_label, "cutlass_i8_i8_matmul", dense_matmul, a, b,
# torch.int8))
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_matmul_scaled", dense_matmul,
a, b, torch.int8))

# cusparseLt i8
a_compressed = compress_to_torch_sparse_semi_structured_mat(a)
@@ -113,38 +110,39 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4",
semi_structured_sparse_dense_gemm, a_compressed, b))


semi_structured_sparse_dense_gemm_scaled(a_compressed,
b,
scale_a=scale,
scale_b=scale)
timers.append(
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4_scaled",
semi_structured_sparse_dense_gemm_scaled, a_compressed, b, scale, scale))

scale_vec = scale.repeat(a_compressed.shape[0])
semi_structured_sparse_dense_gemm_scaled(a_compressed,
b,
scale_a=scale_vec,
scale_b=scale)
timers.append(
bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4_scaled_channel",
semi_structured_sparse_dense_gemm_scaled, a_compressed, b, scale_vec, scale))
semi_structured_sparse_dense_gemm_scaled, a_compressed, b,
scale, scale))

# scale_vec = scale.repeat(a_compressed.shape[0])
# semi_structured_sparse_dense_gemm_scaled(a_compressed,
# b,
# scale_a=scale_vec,
# scale_b=scale)
# timers.append(
# bench_fn(label, sub_label, "cusparseLt_i8_i8_2_4_scaled_channel",
# semi_structured_sparse_dense_gemm_scaled, a_compressed, b,
# scale_vec, scale))

timers.append(
bench_fn(label,
sub_label,
"cusparseLt_i8_i8_2_4_noncached",
semi_structured_sparse_dense_gemm,
a_compressed,
b,
cached=False))
# timers.append(
# bench_fn(label,
# sub_label,
# "cusparseLt_i8_i8_2_4_noncached",
# semi_structured_sparse_dense_gemm,
# a_compressed,
# b,
# cached=False))

if use_fp8:
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
# cutlass fp8
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul-w-scales",
bench_fn(label, sub_label, "cutlass_fp8_fp8_matmul_scaled",
dense_matmul, a, b, torch.float8_e4m3fn))

# cusparseLt fp8
Expand All @@ -159,18 +157,18 @@ def bench(m: int, k: int, n: int, label: str, sub_label: str,

semi_structured_sparse_dense_gemm_scaled(a_compressed, b, scale, scale)
timers.append(
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4_scale",
bench_fn(label, sub_label, "cusparseLt_fp8_fp8_2_4_scaled",
semi_structured_sparse_dense_gemm_scaled, a_compressed, b,
scale, scale))

timers.append(
bench_fn(label,
sub_label,
"cusparseLt_fp8_fp8_2_4_noncached",
semi_structured_sparse_dense_gemm,
a_compressed,
b,
cached=False))
# timers.append(
# bench_fn(label,
# sub_label,
# "cusparseLt_fp8_fp8_2_4_noncached",
# semi_structured_sparse_dense_gemm,
# a_compressed,
# b,
# cached=False))

return timers

@@ -240,6 +238,7 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
with open(f"model_bench-{timestamp}.pkl", "wb") as f:
pickle.dump(all_data, f)


if __name__ == '__main__':

parser = FlexibleArgumentParser(
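
The benchmark hunks above build their timing list through `bench_fn`, whose definition lies outside this diff. As a reference point, the following is a minimal sketch of how such a helper can be written on top of `torch.utils.benchmark`; the exact signature and the one-second `min_run_time` are assumptions for illustration, not the repository's actual implementation.

```python
from typing import Callable

import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement


def bench_fn(label: str, sub_label: str, description: str, fn: Callable,
             *args, **kwargs) -> TMeasurement:
    # Time fn(*args, **kwargs) with torch.utils.benchmark, which takes care
    # of warmup and CUDA synchronization for us.
    globals_ = {"fn": fn, "args": args, "kwargs": kwargs}
    return TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=globals_,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=1)
```

Each returned `Measurement` carries the label/sub-label/description triple, which is what lets the script group kernels per shape and later pickle the collected results into `model_bench-<timestamp>.pkl`.
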
34 changes: 20 additions & 14 deletions vllm/model_executor/layers/sparsity/utils/cusparse_2_4_utils.py
@@ -117,9 +117,10 @@ def semi_structured_sparse_dense_gemm(a_packed: torch.Tensor,
assert a_packed.dtype in [
torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn
], f"Semi structured sparse-dense matmul does not support {a_packed.dtype}"
if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1 and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1
and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
raise ValueError("cuSparseLt does not support"
"contiguous dense matrix for int8 and fp8 types")
if a_packed.dtype in [torch.float16, torch.bfloat16]:
@@ -128,10 +129,12 @@ def semi_structured_sparse_dense_gemm(a_packed: torch.Tensor,

row, col = b_dense.shape
b_dense = _pad_dense_input(b_dense)
if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1 and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
# We have to provide non-contiguous b_dense to cusparseLt for int8 and fp8
if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1
and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
# We have to provide non-contiguous b_dense
# to cusparseLt for int8 and fp8
b_dense = b_dense.t().contiguous().t()

if cached:
@@ -197,9 +200,10 @@ def semi_structured_sparse_dense_gemm_scaled(a_packed: torch.Tensor,
torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn
], f"Semi structured sparse-dense matmul does not support {a_packed.dtype}"

if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1 and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1
and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
raise ValueError("cuSparseLt does not support "
"contiguous dense matrix for int8 and fp8 types")
if a_packed.dtype in [torch.float16, torch.bfloat16]:
@@ -211,10 +215,12 @@ def semi_structured_sparse_dense_gemm_scaled(a_packed: torch.Tensor,
row, col = b_dense.shape
b_dense = _pad_dense_input(b_dense)

if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1 and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
# We have to provide non-contiguous b_dense to cusparseLt for int8 and fp8
if (b_dense.shape[0] > 1 and b_dense.shape[1] > 1
and b_dense.is_contiguous()) and a_packed.dtype in [
torch.int8, torch.float8_e4m3fn
]:
# We have to provide non-contiguous b_dense
# to cusparseLt for int8 and fp8
b_dense = b_dense.t().contiguous().t()

per_tensor_weights = (scale_a.numel() == 1)
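
Both functions above fall back to `b_dense.t().contiguous().t()` when the dense operand is int8 or fp8. The idiom keeps the shape and values of the matrix but rewrites its storage as column-major, so the `is_contiguous()` guard no longer trips. A stand-alone illustration in plain PyTorch (CPU-only, arbitrary shapes, no cuSparseLt involved):

```python
import torch

# A row-major (contiguous) dense operand, like make_rand_tensors produces.
b = torch.randint(-8, 8, (128, 256), dtype=torch.int8)
assert b.is_contiguous()

# Transpose, materialize that layout, transpose back: same shape and values,
# but the storage is now column-major, so the row-major contiguity check fails.
b_colmajor = b.t().contiguous().t()
assert torch.equal(b_colmajor, b)
assert b_colmajor.shape == b.shape
assert not b_colmajor.is_contiguous()
assert b_colmajor.t().is_contiguous()  # contiguous only in the transposed view
```

Only the memory layout changes; cuSparseLt receives the same matrix, laid out the way the int8/fp8 path requires according to the comments above.
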
