Summary: Initial integration for the sparse-fused GEMM. To achieve this, we need to ensure that we compress the weight matrix only once and never decompress it, as decompression is currently unsupported.

Before this change, using `SparseParameter(SparseTensor)` meant that in `MergedColumnParallelLinear` and `QKVParallelLinear`, every time the `weight_loader` loaded a new shard (e.g., the "q" portion of `QKVParallelLinear`), we would decompress the tensor in order to use `narrow` to update the appropriate section of the weight tensor. With this change, `SparseParameter(SparseTensor)` is replaced with `LazyCompressedParameter`, which lets us operate on `uncompressed_data` until we explicitly compress it. At that point, `uncompressed_data` is compressed into `compressed_data` and freed.

Currently, the detection of when to call `compress` is somewhat hacky. For `QKVParallelLinear`, we compress only after the "q", "k", and "v" shard ids have all been loaded; for `MergedColumnParallelLinear`, we compress once we've loaded as many shards as there are outputs (determined by `len(output_sizes)`), which implicitly assumes one shard per output.

Moving away from `SparseParameter(SparseTensor)` means that `SparseTensor` no longer handles dispatching to the custom ops; instead, this is handled by `SparseW16A16LinearMethod`. I believe this is a positive change overall: `SparseTensor` was an unnecessary extra layer of abstraction/indirection, originally designed for the SLoRA work rather than for vLLM.

This did result in the 2:4 sparse implementation breaking. However, it turns out it was already broken (i.e., it was decompressing and running dense within `SparseTensor`), so we "disable" it for now ("disable" meaning decompress and run dense instead). We should revisit all of this infrastructure post-MVP.

---------

Co-authored-by: Andrew Feldman <[email protected]>
1 parent c45f20d · commit 00e7e17
Showing 7 changed files with 148 additions and 123 deletions.
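The compress-trigger heuristic described in the summary can be sketched as follows. This is a hypothetical illustration rather than code from this commit: the `load_shard` helper, the `expected_shard_ids` set, and the offset bookkeeping are all assumed names layered on top of `LazyCompressedParameter`.

```python
import torch
from vllm.model_executor.layers.parameters import LazyCompressedParameter


def load_shard(param: LazyCompressedParameter,
               loaded_shard_ids: set,
               expected_shard_ids: set,
               shard_id: str,
               shard_offset: int,
               loaded_weight: torch.Tensor) -> None:
    # While loading, write straight into the dense view; nothing has to be
    # decompressed because compression has not happened yet.
    param.uncompressed_data.narrow(
        0, shard_offset, loaded_weight.shape[0]).copy_(loaded_weight)

    loaded_shard_ids.add(shard_id)
    # QKVParallelLinear: expected_shard_ids == {"q", "k", "v"}.
    # MergedColumnParallelLinear: one shard id per entry of output_sizes.
    if loaded_shard_ids == expected_shard_ids:
        # Every shard is in place: compress once and free the dense copy.
        param.compress()
```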
The first hunk replaces the `SparseParameter` export with `LazyCompressedParameter` and drops the now-unused `get_param_data` helper:

```diff
@@ -1,10 +1,5 @@
 import torch
-from vllm.model_executor.layers.parameters.sparsity import SparseParameter
+from vllm.model_executor.layers.parameters.lazy_compressed import LazyCompressedParameter


-def get_param_data(param: torch.nn.Parameter) -> torch.Tensor:
-    """Gets parameter data in dense format."""
-    if isinstance(param, SparseParameter):
-        return param.get_dense_data()
-    else:
-        return param.data
+__all__ = [
+    "LazyCompressedParameter",
+]
```
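Call sites that previously went through `get_param_data` for a dense view can now read the dense tensor directly while it still exists. A hedged sketch, assuming `param` is a `LazyCompressedParameter` (the `dense_view` helper is illustrative, not part of this commit):

```python
import torch
from vllm.model_executor.layers.parameters import LazyCompressedParameter


def dense_view(param: LazyCompressedParameter) -> torch.Tensor:
    # Replacement for the deleted get_param_data(): valid only before
    # compress() has freed uncompressed_data.
    assert param.has_uncompressed_data
    return param.uncompressed_data
```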
The second hunk adds the new file defining `LazyCompressedParameter`:

```diff
@@ -0,0 +1,78 @@
+from typing import Type
+
+import numpy
+import torch
+from torch.utils._pytree import tree_map
+
+from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)
+
+
+class LazyCompressedParameter(torch.Tensor):
+    """A tensor subclass that stays dense (`uncompressed_data`) while weight
+    shards are being loaded, then is compressed exactly once via `compress()`,
+    after which the dense copy is freed."""
+
+    @staticmethod
+    def __new__(cls,
+                uncompressed_data: torch.Tensor,
+                storage_format_cls: Type[
+                    CompressedStorageFormat] = SparseBitmaskStorageFormat,
+                compress_transposed: bool = False):
+        self = torch.Tensor._make_wrapper_subclass(
+            cls,
+            size=uncompressed_data.shape,
+            dtype=uncompressed_data.dtype,
+            requires_grad=False)
+        self.storage_format_cls = storage_format_cls
+        self.compressed_data = None
+        self.uncompressed_data = uncompressed_data
+        self.compress_transposed = compress_transposed
+        # Flag checked by the weight-loading code so this wrapper is treated
+        # like a parameter.
+        self._is_param = True
+        return self
+
+    @property
+    def has_compressed_data(self) -> bool:
+        return self.compressed_data is not None
+
+    @property
+    def has_uncompressed_data(self) -> bool:
+        return self.uncompressed_data is not None
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        ret_storage_format_cls = None
+
+        def unwrap(e):
+            # Unwrap to the dense tensor, recording the storage format so
+            # results can be re-wrapped below. Mixing formats in one op is
+            # not supported.
+            nonlocal ret_storage_format_cls
+            if isinstance(e, LazyCompressedParameter):
+                assert (ret_storage_format_cls is None
+                        or ret_storage_format_cls == e.storage_format_cls)
+                ret_storage_format_cls = e.storage_format_cls
+                return e.uncompressed_data
+            return e
+
+        rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+
+        def wrap(e):
+            # Re-wrap tensor results so the lazy-compression semantics
+            # propagate through torch ops performed during loading.
+            if (isinstance(e, torch.Tensor)
+                    and ret_storage_format_cls is not None):
+                return LazyCompressedParameter(
+                    e, storage_format_cls=ret_storage_format_cls)
+            return e
+
+        return tree_map(wrap, rs)
+
+    def compress(self) -> None:
+        if self.uncompressed_data is None:
+            raise ValueError(
+                "Called compress() but uncompressed_data does not exist.")
+
+        density = torch.count_nonzero(
+            self.uncompressed_data).item() / numpy.prod(self.shape)
+
+        # Only compress if there is sufficient sparsity (>= 45%); currently
+        # this threshold applies globally across all formats, including 2:4.
+        if (1 - density) < 0.45:
+            return
+
+        self.compressed_data = self.storage_format_cls.compress(
+            self.uncompressed_data.t()
+            if self.compress_transposed else self.uncompressed_data)
+        del self.uncompressed_data  # free the dense copy
+        self.uncompressed_data = None
```
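A minimal usage sketch of the lifecycle, assuming `magic_wand` is installed; the shapes and sparsity pattern here are illustrative:

```python
import torch
from vllm.model_executor.layers.parameters import LazyCompressedParameter

# Start from a mostly-zero weight so the >= 45% sparsity threshold is met.
weight = torch.zeros(4096, 4096, dtype=torch.float16)
weight[::10, ::10] = 1.0

param = LazyCompressedParameter(weight)

# While uncompressed, ordinary torch ops work through __torch_dispatch__;
# results come back wrapped as LazyCompressedParameter.
q_shard = param.narrow(0, 0, 1024)

# Compress once loading is done; the dense copy is freed.
param.compress()
assert param.has_compressed_data and not param.has_uncompressed_data
```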
This file was deleted.