This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Sparse fused gemm integration #12

Merged
merged 23 commits into from
Feb 14, 2024
54 changes: 27 additions & 27 deletions vllm/model_executor/layers/linear.py
@@ -13,7 +13,7 @@
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
from vllm.model_executor.layers.parameters import SparseParameter, get_param_data
from vllm.model_executor.layers.parameters import LazyCompressedParameter

logger = init_logger(__name__)

@@ -196,7 +196,7 @@ def __init__(
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
param_data = get_param_data(param)
param_data = param.data

if output_dim is not None:
shard_size = param_data.shape[output_dim]
@@ -206,9 +206,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If SparseParameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
if isinstance(param, LazyCompressedParameter):
param.compress()

def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
@@ -257,6 +256,7 @@ def __init__(
linear_method: Optional[LinearMethodBase] = None,
):
self.output_sizes = output_sizes
self.loaded_shards = set()
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
@@ -266,14 +266,9 @@ def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = get_param_data(param)
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
if isinstance(param, SparseParameter):
raise NotImplementedError(
"Passing loaded_shard_id=None not yet supported for SparseParameter"
)

# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
@@ -320,12 +315,17 @@ def weight_loader(self,
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")

self.loaded_shards.add(loaded_shard_id)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If Parameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
# This is a temporary workaround: we only want to compress once all of the
# shards are loaded. For now we check whether the number of loaded shards
# matches the number of expected outputs, assuming one shard per output.
all_shards_loaded = (len(self.loaded_shards) == len(self.output_sizes))
if all_shards_loaded and isinstance(param, LazyCompressedParameter):
param.compress()


class QKVParallelLinear(ColumnParallelLinear):
@@ -369,6 +369,7 @@ def __init__(
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
self.loaded_shards = set()
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
@@ -389,14 +390,9 @@ def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = get_param_data(param)
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
if isinstance(param, SparseParameter):
raise NotImplementedError(
"Passing loaded_shard_id=None not yet supported for SparseParameter"
)

# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
@@ -460,9 +456,14 @@ def weight_loader(self,
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If SparseParameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
self.loaded_shards.add(loaded_shard_id)

# Temporary workaround: only compress once all of the shards are loaded.
# For the QKV matrix this means loading the "q", "k", and "v" shards.
all_shards_loaded = (self.loaded_shards == set(["q", "k", "v"]))
if all_shards_loaded and isinstance(param, LazyCompressedParameter):
param.compress()


class RowParallelLinear(torch.nn.Module):
@@ -546,7 +547,7 @@ def __init__(
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = get_param_data(param)
param_data = param.data
if input_dim is not None:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
@@ -555,9 +556,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

# If SparseParameter, repack dense data as sparse.
if isinstance(param, SparseParameter):
param.pack()
if isinstance(param, LazyCompressedParameter):
param.compress()

def forward(self, input_):
# Set up backprop all-reduce.
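The weight-loader changes above all follow the same pattern: copy each shard into the dense parameter and compress only once every expected shard has been loaded. A minimal sketch of that pattern, assuming illustrative names (ShardedWeightLoader, copy_fn, compress_fn are stand-ins, not the actual vLLM API):

class ShardedWeightLoader:

    def __init__(self, expected_shards, compress_fn):
        self.expected_shards = set(expected_shards)  # e.g. {"q", "k", "v"}
        self.loaded_shards = set()
        self.compress_fn = compress_fn

    def load_shard(self, shard_id, copy_fn):
        copy_fn()  # copy this shard's data into its slice of param.data
        self.loaded_shards.add(shard_id)
        # Compressing earlier would freeze a partially-filled tensor, so wait
        # until every expected shard has been copied in.
        if self.loaded_shards == self.expected_shards:
            self.compress_fn()

# Example mimicking the QKV case from the diff:
import torch
param = torch.zeros(6)
loader = ShardedWeightLoader({"q", "k", "v"}, compress_fn=lambda: print("compress"))
loader.load_shard("q", lambda: param[0:2].fill_(1.0))
loader.load_shard("k", lambda: param[2:4].fill_(1.0))
loader.load_shard("v", lambda: param[4:6].fill_(1.0))  # last shard triggers compress_fn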
13 changes: 4 additions & 9 deletions vllm/model_executor/layers/parameters/__init__.py
@@ -1,10 +1,5 @@
import torch
from vllm.model_executor.layers.parameters.sparsity import SparseParameter
from vllm.model_executor.layers.parameters.lazy_compressed import LazyCompressedParameter


def get_param_data(param: torch.nn.Parameter) -> torch.Tensor:
"""Gets parameter data in dense format."""
if isinstance(param, SparseParameter):
return param.get_dense_data()
else:
return param.data
__all__ = [
"LazyCompressedParameter",
]
78 changes: 78 additions & 0 deletions vllm/model_executor/layers/parameters/lazy_compressed.py
@@ -0,0 +1,78 @@
import numpy
import torch
from torch.utils._pytree import tree_map

from typing import Type
from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)


class LazyCompressedParameter(torch.Tensor):

@staticmethod
def __new__(cls,
uncompressed_data: torch.Tensor,
storage_format_cls: Type[
CompressedStorageFormat] = SparseBitmaskStorageFormat,
compress_transposed: bool = False):
self = torch.Tensor._make_wrapper_subclass(
cls,
size=uncompressed_data.shape,
dtype=uncompressed_data.dtype,
requires_grad=False)
self.storage_format_cls = storage_format_cls
self.compressed_data = None
self.uncompressed_data = uncompressed_data
self.compress_transposed = compress_transposed
self._is_param = True

return self

@property
def has_compressed_data(self) -> bool:
return (self.compressed_data is not None)

@property
def has_uncompressed_data(self) -> bool:
return (self.uncompressed_data is not None)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
ret_storage_format_cls = None

def unwrap(e):
nonlocal ret_storage_format_cls
if isinstance(e, LazyCompressedParameter):
assert ret_storage_format_cls is None or ret_storage_format_cls == e.storage_format_cls
ret_storage_format_cls = e.storage_format_cls
return e.uncompressed_data if isinstance(
e, LazyCompressedParameter) else e

rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

def wrap(e):
if isinstance(e,
torch.Tensor) and ret_storage_format_cls is not None:
return LazyCompressedParameter(
e, storage_format_cls=ret_storage_format_cls)
return e

rs = tree_map(wrap, rs)
return rs

def compress(self) -> None:
    if self.uncompressed_data is None:
        raise ValueError(
            "Called compress() but uncompressed_data does not exist.")

    density = torch.count_nonzero(
        self.uncompressed_data).item() / numpy.prod(self.shape)

    # Only compress if we have sufficient sparsity (>=45%); currently
    # this applies globally across all formats, including 2:4.
    if (1 - density) < 0.45:
        return
self.compressed_data = self.storage_format_cls.compress(
self.uncompressed_data.t(
) if self.compress_transposed else self.uncompressed_data)
del self.uncompressed_data # free memory
self.uncompressed_data = None
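A hypothetical usage sketch of the LazyCompressedParameter defined above, assuming the magic_wand package (with its default SparseBitmaskStorageFormat) is installed; shapes and values are arbitrary:

import torch
from vllm.model_executor.layers.parameters import LazyCompressedParameter

# Build a dense weight that is ~94% sparse, well above the 45% threshold.
dense = torch.zeros(256, 256, dtype=torch.float16)
dense[::4, ::4] = 1.0

param = LazyCompressedParameter(dense)  # wraps the dense tensor, no packing yet
assert param.has_uncompressed_data and not param.has_compressed_data

param.compress()  # sparsity check passes, data is packed, dense copy is freed
assert param.has_compressed_data and not param.has_uncompressed_data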
66 changes: 0 additions & 66 deletions vllm/model_executor/layers/parameters/sparsity.py

This file was deleted.

4 changes: 2 additions & 2 deletions vllm/model_executor/layers/sparsity/sparse_w16a16.py
@@ -5,7 +5,7 @@
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig

from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)


class SparseW16A16Config(SparsityConfig):
@@ -23,7 +23,7 @@ def __repr__(self) -> str:

@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
return SparseBitmaskStorageFormat
return SparseBEGemmStorageFormat

@classmethod
def get_name(cls) -> str:
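A small sketch of what the change above means for downstream code, assuming vllm and magic_wand are importable: after this change the config hands out SparseBEGemmStorageFormat instead of SparseBitmaskStorageFormat.

from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config
from magic_wand import SparseBEGemmStorageFormat

storage_format_cls = SparseW16A16Config.get_storage_format_cls()
assert storage_format_cls is SparseBEGemmStorageFormat  # format chosen after this change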
47 changes: 33 additions & 14 deletions vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
@@ -5,9 +5,9 @@

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter
from magic_wand import (CompressedStorageFormat,
SparseSemiStructuredStorageFormat)
from vllm.model_executor.layers.parameters import LazyCompressedParameter
from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
from magic_wand.ops import be_ds_gemm


class SparseW16A16LinearMethod(LinearMethodBase):
@@ -27,10 +27,15 @@ def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
weight = SparseParameter(shape=torch.Size(
(output_size_per_partition, input_size_per_partition)),
dtype=params_dtype,
storage_format_cls=self.storage_format_cls)
supports_linear = (self.storage_format_cls !=
SparseBEGemmStorageFormat)
weight = LazyCompressedParameter(
torch.empty((output_size_per_partition, input_size_per_partition),
dtype=params_dtype),
storage_format_cls=self.storage_format_cls,
# if we don't support F.linear or something analogous,
# transpose when we compress so we can use a basic matmul
compress_transposed=not supports_linear)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

@@ -42,14 +47,28 @@ def apply_weights(
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sparse_weight = weights["weight"]
w: LazyCompressedParameter = weights["weight"]

if self.storage_format_cls == SparseSemiStructuredStorageFormat:
output = F.linear(x, sparse_weight, bias)
return output
# If we never compressed (likely due to insufficient sparsity), i.e. we
# still have uncompressed_data, run the normal dense path.
if w.has_uncompressed_data:
assert not w.has_compressed_data
output = F.linear(x, w.uncompressed_data, bias)
# The current 2:4 implementation was running dense, so ignore it for now
# and instead just explicitly decompress as usual.
# elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
# assert bias is None
# raise NotImplementedError
elif self.storage_format_cls == SparseBEGemmStorageFormat:
assert bias is None
assert w.compress_transposed
out_shape = (x.shape[:-1] + (w.shape[0], ))
reshaped_x = x.reshape(-1, x.shape[-1])
y = be_ds_gemm(reshaped_x, w.compressed_data)
return y.reshape(out_shape)
else:

# Standard matrix multiply
# Uncompress to dense
output = F.linear(x, sparse_weight.to_dense(), bias)
return output
assert not w.compress_transposed
output = F.linear(x, w.compressed_data.decompress(), bias)
return output
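The SparseBEGemmStorageFormat branch above flattens all leading dimensions before the sparse GEMM and restores them afterwards. A small illustrative helper showing just that reshape bookkeeping (the gemm argument is a stand-in for be_ds_gemm, not the real kernel):

import torch

def flatten_gemm_restore(x: torch.Tensor, out_features: int, gemm) -> torch.Tensor:
    # Keep the batch/sequence dims; replace the hidden dim with out_features.
    out_shape = x.shape[:-1] + (out_features, )
    # Collapse all leading dims so the GEMM sees a 2-D matrix.
    reshaped_x = x.reshape(-1, x.shape[-1])  # (num_tokens, in_features)
    y = gemm(reshaped_x)                     # (num_tokens, out_features)
    return y.reshape(out_shape)

# Example with a plain dense matmul standing in for the sparse kernel:
w = torch.randn(32, 16)    # (out_features, in_features)
x = torch.randn(2, 7, 16)  # (batch, seq, in_features)
y = flatten_gemm_restore(x, w.shape[0], lambda t: t @ w.t())
assert y.shape == (2, 7, 32)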