Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
Sparse fused gemm integration (#12)
Browse files Browse the repository at this point in the history
Summary:

Initial integration for the sparse-fused gemm. To achieve this, we need
to ensure that we compress the weight matrix only once and never
decompress it, as decompression is currently unsupported.

Before this change, using `SparseParameter(SparseTensor)` meant that in
`MergedColumnParallelLinear` and `QKVParallelLinear` every time a new
shard was loaded by the `weight_loader` (e.g., the "q" portion of
`QKVParallelLinear`), we would decompress the tensor in-order to use
narrow to update the appropriate section of the weight tensor. With this
change, `SparseParameter(SparseTensor)` is replaced with
`LazyCompressedParameter`, which allows us to operate on
`uncompressed_data` until we explicitly compress it. At that point, the
`uncompressed_data` is compressed into `compressed_data` and freed.
Currently, the detection of when to call compress is somewhat hacky. For
`QKVParallelLinear`, we compress only after inserting "q", "k", and "v"
shard ids, and for `MergedColumnParallelLinear`, we compress once we've
inserted the same number of shards as outputs (determined by
`len(output_sizes)`), which implicitly assumes one shard per output.

Moving away from `SparseParameter(SparseTensor)` means that
`SparseTensor` no longer handles dispatching to the custom ops; instead,
this is handled by `SparseW16A16LinearMethod`. I believe this is a
positive change overall. `SparseTensor` was an unnecessary extra layer
of abstraction/indirection originally designed for the SLoRA work, not
vLLM.

This did result in the 2:4 sparse implementation breaking. However, it
turns out it was already broken (i.e., it was decompressing and running
dense within `SparseTensor`), so we "disable" it for now ("disable"
meaning decompress and run dense instead).

We should revisit all of this infrastructure post-MVP.

---------

Co-authored-by: Andrew Feldman <[email protected]>
  • Loading branch information
2 people authored and tlrmchlsmth committed Feb 21, 2024
1 parent b1d47c1 commit fbfd1aa
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 123 deletions.
54 changes: 27 additions & 27 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
from vllm.model_executor.layers.parameters import SparseParameter, get_param_data
from vllm.model_executor.layers.parameters import LazyCompressedParameter

logger = init_logger(__name__)

Expand Down Expand Up @@ -192,7 +192,7 @@ def __init__(
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
param_data = get_param_data(param)
param_data = param.data

if output_dim is not None:
shard_size = param_data.shape[output_dim]
Expand All @@ -202,9 +202,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If SparseParameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
if isinstance(param, LazyCompressedParameter):
param.compress()

def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
Expand Down Expand Up @@ -253,6 +252,7 @@ def __init__(
linear_method: Optional[LinearMethodBase] = None,
):
self.output_sizes = output_sizes
self.loaded_shards = set()
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
Expand All @@ -262,14 +262,9 @@ def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = get_param_data(param)
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
if isinstance(param, SparseParameter):
raise NotImplementedError(
"Passing loaded_shard_id=None not yet supported for SparseParameter"
)

# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
Expand Down Expand Up @@ -316,12 +311,17 @@ def weight_loader(self,
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")

self.loaded_shards.add(loaded_shard_id)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If Parameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
# This is super hacky for now but we basically want to only compress once all
# of the shards are loaded, right now we just check if the number of shards
# loaded matches the number of outputs expected, assuming one shard per output
all_shards_loaded = (len(self.loaded_shards) == len(self.output_sizes))
if all_shards_loaded and isinstance(param, LazyCompressedParameter):
param.compress()


class QKVParallelLinear(ColumnParallelLinear):
Expand Down Expand Up @@ -365,6 +365,7 @@ def __init__(
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
self.loaded_shards = set()
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
Expand All @@ -385,14 +386,9 @@ def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = get_param_data(param)
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
if isinstance(param, SparseParameter):
raise NotImplementedError(
"Passing loaded_shard_id=None not yet supported for SparseParameter"
)

# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
Expand Down Expand Up @@ -456,9 +452,14 @@ def weight_loader(self,
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If SparseParameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
self.loaded_shards.add(loaded_shard_id)

# This is super hacky for now but we basically want to only compress once
# all of the shards are loaded, for the QKV matrix this means
# loading shards "q", "k" and "v"
all_shards_loaded = (self.loaded_shards == set(["q", "k", "v"]))
if all_shards_loaded and isinstance(param, LazyCompressedParameter):
param.compress()


class RowParallelLinear(torch.nn.Module):
Expand Down Expand Up @@ -540,7 +541,7 @@ def __init__(
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = get_param_data(param)
param_data = param.data
if input_dim is not None:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
Expand All @@ -549,9 +550,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If SparseParameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
if isinstance(param, LazyCompressedParameter):
param.compress()

def forward(self, input_):
# Set up backprop all-reduce.
Expand Down
13 changes: 4 additions & 9 deletions vllm/model_executor/layers/parameters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import torch
from vllm.model_executor.layers.parameters.sparsity import SparseParameter
from vllm.model_executor.layers.parameters.lazy_compressed import LazyCompressedParameter


def get_param_data(param: torch.nn.Parameter) -> torch.Tensor:
"""Gets parameter data in dense format."""
if isinstance(param, SparseParameter):
return param.get_dense_data()
else:
return param.data
__all__ = [
"LazyCompressedParameter",
]
78 changes: 78 additions & 0 deletions vllm/model_executor/layers/parameters/lazy_compressed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy
import torch
from torch.utils._pytree import tree_map

from typing import Type
from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)


class LazyCompressedParameter(torch.Tensor):

@staticmethod
def __new__(cls,
uncompressed_data: torch.Tensor,
storage_format_cls: Type[
CompressedStorageFormat] = SparseBitmaskStorageFormat,
compress_transposed: bool = False):
self = torch.Tensor._make_wrapper_subclass(
cls,
size=uncompressed_data.shape,
dtype=uncompressed_data.dtype,
requires_grad=False)
self.storage_format_cls = storage_format_cls
self.compressed_data = None
self.uncompressed_data = uncompressed_data
self.compress_transposed = compress_transposed
self._is_param = True

return self

@property
def has_compressed_data(self) -> bool:
return (self.compressed_data is not None)

@property
def has_uncompressed_data(self) -> bool:
return (self.uncompressed_data is not None)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
ret_storage_format_cls = None

def unwrap(e):
nonlocal ret_storage_format_cls
if isinstance(e, LazyCompressedParameter):
assert ret_storage_format_cls is None or ret_storage_format_cls == e.storage_format_cls
ret_storage_format_cls = e.storage_format_cls
return e.uncompressed_data if isinstance(
e, LazyCompressedParameter) else e

rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

def wrap(e):
if isinstance(e,
torch.Tensor) and ret_storage_format_cls is not None:
return LazyCompressedParameter(
e, storage_format_cls=ret_storage_format_cls)
return e

rs = tree_map(wrap, rs)
return rs

def compress(self) -> None:
density = torch.count_nonzero(
self.uncompressed_data).item() / numpy.prod(self.shape)

# only compress if we have sufficient sparsity (>=45%), currently
# this applies globally across all formats including 2:4
if (1 - density) < 0.45:
return

if self.uncompressed_data is None:
raise ValueError(
"Called compress() but uncompressed_data does not exist.")
self.compressed_data = self.storage_format_cls.compress(
self.uncompressed_data.t(
) if self.compress_transposed else self.uncompressed_data)
del self.uncompressed_data # free memory
self.uncompressed_data = None
66 changes: 0 additions & 66 deletions vllm/model_executor/layers/parameters/sparsity.py

This file was deleted.

4 changes: 2 additions & 2 deletions vllm/model_executor/layers/sparsity/sparse_w16a16.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig

from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)


class SparseW16A16Config(SparsityConfig):
Expand All @@ -23,7 +23,7 @@ def __repr__(self) -> str:

@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
return SparseBitmaskStorageFormat
return SparseBEGemmStorageFormat

@classmethod
def get_name(cls) -> str:
Expand Down
47 changes: 33 additions & 14 deletions vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter
from magic_wand import (CompressedStorageFormat,
SparseSemiStructuredStorageFormat)
from vllm.model_executor.layers.parameters import LazyCompressedParameter
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
from magic_wand.ops import be_ds_gemm


class SparseW16A16LinearMethod(LinearMethodBase):
Expand All @@ -27,10 +27,15 @@ def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
weight = SparseParameter(shape=torch.Size(
(output_size_per_partition, input_size_per_partition)),
dtype=params_dtype,
storage_format_cls=self.storage_format_cls)
supports_linear = (self.storage_format_cls !=
SparseBEGemmStorageFormat)
weight = LazyCompressedParameter(
torch.empty((output_size_per_partition, input_size_per_partition),
dtype=params_dtype),
storage_format_cls=self.storage_format_cls,
# if we don't support F.linear or something analogous,
# transpose when we compress so we can use a basic matmul
compress_transposed=not supports_linear)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

Expand All @@ -42,14 +47,28 @@ def apply_weights(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sparse_weight = weights["weight"]
w: LazyCompressedParameter = weights["weight"]

if self.storage_format_cls == SparseSemiStructuredStorageFormat:
output = F.linear(x, sparse_weight, bias)
return output
# if we never compressed (likely due to insufficient sparsity),
# i.e. have uncompressed_data run normally
if w.has_uncompressed_data:
assert not w.has_compressed_data
output = F.linear(x, w.uncompressed_data, bias)
# The current 2:4 implementation was running dense so ignore it
# for now and instead just explicitly decompress as usual
# elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
# assert bias is None
# raise NotImplementedError
elif self.storage_format_cls == SparseBEGemmStorageFormat:
assert bias is None
assert w.compress_transposed
out_shape = (x.shape[:-1] + (w.shape[0], ))
reshaped_x = x.reshape(-1, x.shape[-1])
y = be_ds_gemm(reshaped_x, w.compressed_data)
return y.reshape(out_shape)
else:

# Standard matrix multiply
# Uncompress to dense
output = F.linear(x, sparse_weight.to_dense(), bias)
return output
assert not w.compress_transposed
output = F.linear(x, w.compressed_data.decompress(), bias)
return output
Loading

0 comments on commit fbfd1aa

Please sign in to comment.