Skip to content

Commit

Permalink
Format vllm code
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 11, 2024
1 parent 72f4577 commit 208b2a0
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 40 deletions.
25 changes: 12 additions & 13 deletions benchmarks/cutlass_benchmarks/sp_fp8_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
import dataclasses
import itertools
import multiprocessing as mp
import os
import pickle as pkl
import time
import traceback
from multiprocessing import Process, Queue
from pathlib import Path
Expand All @@ -15,11 +15,11 @@
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm.utils import FlexibleArgumentParser
import vllm._custom_ops as ops
from utils import make_n_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES

import vllm._custom_ops as ops
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
Expand Down Expand Up @@ -490,8 +490,8 @@ def run_kernels_on_gpus(
bench = BenchMM(cuda_graph_params, label, sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, ArgPool(As),
ArgPool(BComps), ArgPool(Es),
scale_a, scale_b, torch.bfloat16)
ArgPool(BComps), ArgPool(Es), scale_a,
scale_b, torch.bfloat16)

# Run the benchmark
result = bench.run()
Expand Down Expand Up @@ -575,8 +575,8 @@ def bench_fp8(dtype: torch.dtype, with_cuda_graph: Optional[int],


def bench(dtype: torch.dtype, with_cuda_graph: Optional[int],
with_arg_pool: Optional[int], m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
with_arg_pool: Optional[int], m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, with_cuda_graph, with_arg_pool, m, k, n, label,
sub_label)
Expand All @@ -599,9 +599,8 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
if args.with_cuda_graph else label
label = f"{label}-argpool_{args.with_arg_pool}" \
if args.with_arg_pool else label
timers = bench(args.dtype, args.with_cuda_graph,
args.with_arg_pool, m, k, n, label,
f"MKN=({m}x{k}x{n})")
timers = bench(args.dtype, args.with_cuda_graph, args.with_arg_pool, m,
k, n, label, f"MKN=({m}x{k}x{n})")

print_timers(timers)
results.extend(timers)
Expand Down
18 changes: 7 additions & 11 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
}
}

void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out,
torch::Tensor const& a,
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& e,
torch::Tensor const& b,
torch::Tensor const& a_scales,
Expand All @@ -317,15 +316,12 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out,
}
}

void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& e,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
void cutlass_scaled_sparse_mm_azp_sm90(
torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& e,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

Expand Down
8 changes: 4 additions & 4 deletions csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

if (bias) {
Expand Down
20 changes: 11 additions & 9 deletions tests/kernels/test_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Optional, Type, Tuple
from typing import Optional, Tuple, Type

import pytest
import torch
Expand Down Expand Up @@ -86,8 +86,9 @@ def prune_to_2_4(tensor):
return pruned.reshape(original_shape)


def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

Expand Down Expand Up @@ -464,19 +465,20 @@ def test_cutlass_sparse_subset():
m, n, k = 512, 512, 512

# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

print("in test")

out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
Expand Down
7 changes: 4 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,8 @@ def cutlass_compress_entry(a: torch.Tensor) \


def cutlass_scaled_sparse_mm(
a: torch.Tensor, # row-major activations
b: torch.Tensor, # row-major weight matrix
a: torch.Tensor, # row-major activations
b: torch.Tensor, # row-major weight matrix
e: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
Expand All @@ -578,7 +578,8 @@ def cutlass_scaled_sparse_mm(
n = a_t.shape[1]
out = torch.empty((n, m), dtype=out_dtype, device=a.device).t()

torch.ops._C.cutlass_scaled_sparse_mm(out, b, e, a_t, scale_b, scale_a, bias)
torch.ops._C.cutlass_scaled_sparse_mm(out, b, e, a_t, scale_b, scale_a,
bias)

return out.t()

Expand Down

0 comments on commit 208b2a0

Please sign in to comment.