Abf149/fix semi structured sparse (#16)
SUMMARY:
- Fix a bug whereby the 2:4 semi-structured sparse path was not being invoked (the layer was silently running dense)
- Eschew the SparseTensor-based implementation

TESTING:
- examples/offline_inference_semi_structured_sparse.py
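
  The example script itself is not part of this diff; the following is a rough sketch of how such an offline-inference driver presumably exercises the semi-structured sparse path. The model name and the sparsity argument value are illustrative assumptions, not taken from this commit.

    from vllm import LLM, SamplingParams

    # Placeholder model name; substitute one whose weights are 2:4 pruned.
    llm = LLM(model="nm-testing/llama-2-7b-pruned-2of4",
              sparsity="semi_structured_sparse_w16a16")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)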

---------

Co-authored-by: Lucas Wilkinson <[email protected]>
afeldman-nm and LucasWilkinson authored Feb 16, 2024
1 parent 4f225b4 commit 8599f81
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
@@ -6,7 +6,10 @@
 from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 from vllm.model_executor.layers.parameters import LazyCompressedParameter
-from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
+from magic_wand.semi_structured import (pad_tensor_to_multiple,
+                                        extract_valid_rows)
+from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat,
+                        SparseSemiStructuredStorageFormat)
 from magic_wand.ops import be_ds_gemm


@@ -54,11 +57,18 @@ def apply_weights(
         if w.has_uncompressed_data:
             assert not w.has_compressed_data
             output = F.linear(x, w.uncompressed_data, bias)
-        # The current 2:4 implementation was running dense so ignore it
-        # for now and instead just explicitly decompress as usual
-        # elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
-        #     assert bias is None
-        #     raise NotImplementedError
+        elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
+            assert bias is None
+            w_encap = w.compressed_data.encapsulated_torch_sparse_tensor
+            out_shape = (x.shape[:-1] + (w_encap.shape[0], ))
+            reshaped_x, valid_rows_range = pad_tensor_to_multiple(
+                x.reshape(-1, x.shape[-1]), 8)
+            output = F.linear(
+                reshaped_x, w_encap,
+                torch.nn.Parameter(torch.zeros((w_encap.shape[0], ))).to(
+                    reshaped_x.dtype).to(reshaped_x.device)).contiguous()
+            output = extract_valid_rows(output, valid_rows_range)
+            return output.reshape(out_shape)
         elif self.storage_format_cls == SparseBEGemmStorageFormat:
             assert bias is None
             assert w.compress_transposed
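A note on the new branch above: the semi-structured (2:4) GEMM requires the dense activation's row count to be aligned (here, to a multiple of 8), so the flattened input is zero-padded before F.linear and the padding rows are stripped afterward; an explicit zero bias is also materialized, presumably because this kernel path does not accept bias=None. The helpers below are a minimal sketch of what magic_wand's pad_tensor_to_multiple and extract_valid_rows presumably do; they are stand-ins inferred from the call sites, not the library's actual implementation.

    import torch
    import torch.nn.functional as F

    def pad_tensor_to_multiple(t: torch.Tensor, multiple: int):
        # Zero-pad dim 0 up to the next multiple of `multiple`, and report
        # which rows carry real data so the padding can be stripped later.
        rows = t.shape[0]
        padded_rows = ((rows + multiple - 1) // multiple) * multiple
        # F.pad's spec for a 2-D tensor is (left, right, top, bottom);
        # only the bottom of dim 0 is padded here.
        padded = F.pad(t, (0, 0, 0, padded_rows - rows))
        return padded, range(0, rows)

    def extract_valid_rows(t: torch.Tensor, valid: range) -> torch.Tensor:
        # Drop the rows that existed purely for alignment.
        return t[valid.start:valid.stop]

With these stand-ins the branch's shape bookkeeping round-trips: an input flattened to, say, 13 rows is padded to 16, multiplied against the encapsulated sparse weight, trimmed back to 13 valid rows, and reshaped to out_shape.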
