diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
index 65713a1bf15b3..b194e984a9254 100644
--- a/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
+++ b/vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
@@ -6,7 +6,10 @@
 from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 from vllm.model_executor.layers.parameters import LazyCompressedParameter
-from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat)
+from magic_wand.semi_structured import (pad_tensor_to_multiple,
+                                        extract_valid_rows)
+from magic_wand import (CompressedStorageFormat, SparseBEGemmStorageFormat,
+                        SparseSemiStructuredStorageFormat)
 from magic_wand.ops import be_ds_gemm
@@ -54,11 +57,18 @@ def apply_weights(
         if w.has_uncompressed_data:
             assert not w.has_compressed_data
             output = F.linear(x, w.uncompressed_data, bias)
-        # The current 2:4 implementation was running dense so ignore it
-        # for now and instead just explicitly decompress as usual
-        # elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
-        #     assert bias is None
-        #     raise NotImplementedError
+        elif self.storage_format_cls == SparseSemiStructuredStorageFormat:
+            assert bias is None
+            w_encap = w.compressed_data.encapsulated_torch_sparse_tensor
+            out_shape = (x.shape[:-1] + (w_encap.shape[0], ))
+            reshaped_x, valid_rows_range = pad_tensor_to_multiple(
+                x.reshape(-1, x.shape[-1]), 8)
+            output = F.linear(
+                reshaped_x, w_encap,
+                torch.nn.Parameter(torch.zeros((w_encap.shape[0], ))).to(
+                    reshaped_x.dtype).to(reshaped_x.device)).contiguous()
+            output = extract_valid_rows(output, valid_rows_range)
+            return output.reshape(out_shape)
         elif self.storage_format_cls == SparseBEGemmStorageFormat:
             assert bias is None
             assert w.compress_transposed
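
Note on the new SparseSemiStructuredStorageFormat branch above: the activations are flattened to 2D, zero-padded so the row count is a multiple of 8 (presumably what the 2:4 semi-structured kernel expects), run through F.linear against the encapsulated sparse weight, and then the padding rows are trimmed before reshaping back to the original leading dimensions. The standalone sketch below only illustrates that pad-compute-trim pattern; the helpers pad_rows_to_multiple and take_valid_rows are hypothetical pure-torch stand-ins for magic_wand.semi_structured.pad_tensor_to_multiple / extract_valid_rows, and a dense weight stands in for the encapsulated 2:4 tensor, so it runs without magic_wand or sparse-capable hardware.

# Sketch under stated assumptions, not the magic_wand implementation.
import torch
import torch.nn.functional as F


def pad_rows_to_multiple(x: torch.Tensor, multiple: int):
    """Zero-pad dim 0 of `x` up to the next multiple of `multiple`.

    Returns the padded tensor and the range of rows holding real data.
    (Hypothetical stand-in for magic_wand.semi_structured.pad_tensor_to_multiple.)
    """
    rows = x.shape[0]
    padded_rows = ((rows + multiple - 1) // multiple) * multiple
    if padded_rows == rows:
        return x, range(rows)
    pad = torch.zeros((padded_rows - rows, *x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=0), range(rows)


def take_valid_rows(y: torch.Tensor, valid_rows: range) -> torch.Tensor:
    """Drop the padded rows again (stand-in for extract_valid_rows)."""
    return y[valid_rows.start:valid_rows.stop]


if __name__ == "__main__":
    x = torch.randn(5, 16)        # 5 tokens, hidden size 16 (not a multiple of 8)
    w = torch.randn(32, 16)       # dense stand-in for the 2:4 compressed weight
    padded_x, valid = pad_rows_to_multiple(x, 8)   # pad 5 rows -> 8 rows
    out = F.linear(padded_x, w)                    # (8, 32)
    out = take_valid_rows(out, valid)              # trim back to (5, 32)
    assert out.shape == (5, 32)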