get uncompressed to work; update gemm to use contiguous; use alex's utils instead of our decompressor
dsikka committed Nov 1, 2024
1 parent f37ffd3 commit 0702869
Showing 3 changed files with 65 additions and 47 deletions.
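For context, a minimal sketch of the torch 2:4 semi-structured sparsity API this diff builds on (not code from the commit); it assumes torch >= 2.1 and an Ampere-or-newer CUDA GPU:

# Sketch: pack a 2:4-sparse fp16 weight into torch's semi-structured format
# and multiply through it. Assumes CUDA SM80+ and torch >= 2.1.
import torch
from torch.sparse import to_sparse_semi_structured

# Enforce a 2:4 pattern: at most 2 non-zeros in every group of 4 along a row.
W = torch.randn(128, 128, dtype=torch.float16, device="cuda")
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").tile(128, 32)
W = W * mask

W_sparse = to_sparse_semi_structured(W)   # packed values + index metadata
x = torch.randn(128, 64, dtype=torch.float16, device="cuda")
out = torch.mm(W_sparse, x)               # sparse @ dense matmul
assert torch.allclose(out, torch.mm(W, x), atol=1e-2)

to_sparse_semi_structured stores the packed non-zero values plus the index metadata that the meta parameter in this diff mirrors.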
@@ -359,6 +359,7 @@ def get_scheme(
         """
         scheme = CompressedTensors24(
             model_compressor=self.model_compressor,
+            layer_name=layer_name
         )

         return scheme
@@ -9,14 +9,15 @@
     compress_to_torch_sparse_semi_structured_mat,
     semi_structured_dense_sparse_T_gemm
 )

 from torch.sparse import to_sparse_semi_structured
+from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import sparse_semi_structured_to_dense_cutlass, sparse_semi_structured_from_dense_cutlass
 __all__ = ["CompressedTensors24"]

 class CompressedTensors24(CompressedTensorsScheme):
-    def __init__(self, model_compressor: Optional[ModelCompressor] = None):
+    def __init__(self, model_compressor: Optional[ModelCompressor] = None, layer_name = None):
         self.model_compressor = model_compressor
+        self.layer_name = layer_name

     @classmethod
     def get_min_capability(cls) -> int:
         return 80
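A sketch of the round trip through the two newly imported helpers ("alex's utils" per the commit message), with signatures inferred from their call sites later in this diff; it assumes a CUDA device and fp16 weights with an exact 2:4 pattern:

# Sketch: pack a 2:4-sparse dense tensor into the CUTLASS compressed-values-
# plus-metadata layout and unpack it back. Signatures assumed from the calls
# in process_weights_after_loading below.
import torch
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

dense = torch.randn(128, 128, dtype=torch.float16, device="cuda")
dense = dense * torch.tensor([1, 1, 0, 0], device="cuda").tile(128, 32)  # 2:4

comp, meta = sparse_semi_structured_from_dense_cutlass(dense)  # pack
roundtrip = sparse_semi_structured_to_dense_cutlass(comp, meta)  # unpack
assert torch.equal(roundtrip, dense)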
@@ -27,6 +28,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

+        compressed = True  # toggle based on the case we're running
         weights_dtype = params_dtype
         weights = ModelWeightParameter(data=torch.empty(
             sum(output_partition_sizes),
@@ -36,30 +38,41 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             output_dim=0,
             weight_loader=weight_loader)

-        bits_per_weight_element = weights.itemsize * 8
-
-        meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16
-
-        meta_input_size = (
-            input_size_per_partition // 32
-            if bits_per_weight_element == 8
-            else input_size_per_partition // 16
-        )
-        meta = ModelWeightParameter(data=torch.empty(
-            sum(output_partition_sizes),
-            meta_input_size,
-            dtype=meta_dtype),
+        # parameter to store the uncompressed weight or decompressed weight
+        weight_unpacked = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=weights_dtype),
             input_dim=1,
             output_dim=0,
             weight_loader=weight_loader)

-        # TODO: replace weight_packed name, with something
-        # more meaningful, like sparse24_packed, this will
-        # require changes on compressed_tensors side
-        layer.register_parameter("weight_packed", weights)
-        layer.register_parameter("meta", meta)
+        # For the compressed case
+        if compressed:
+            bits_per_weight_element = weights.itemsize * 8
+            meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16
+
+            meta_input_size = (
+                input_size_per_partition // 32
+                if bits_per_weight_element == 8
+                else input_size_per_partition // 16
+            )
+            meta = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                meta_input_size,
+                dtype=meta_dtype),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader)
+
+            # TODO: replace the weight_packed name with something more
+            # meaningful, like sparse24_packed; this will require changes
+            # on the compressed_tensors side
+            layer.register_parameter("weight_packed", weights)
+            layer.register_parameter("meta", meta)
+
+        layer.register_parameter("weight", weight_unpacked)

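The meta sizing above follows from the 2:4 metadata encoding: each kept element carries a 2-bit position index, so one 16-bit meta word covers 16 input columns of 16-bit weights and one 32-bit word covers 32 columns of 8-bit weights. A small sketch of the resulting shapes, with hypothetical layer sizes:

# Worked example of the metadata shape math above (hypothetical sizes).
import torch

def meta_shape(output_size: int, input_size: int, weights_dtype: torch.dtype):
    bits_per_weight_element = torch.empty(0, dtype=weights_dtype).itemsize * 8
    meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16
    meta_input_size = (input_size // 32 if bits_per_weight_element == 8
                       else input_size // 16)
    return meta_dtype, (output_size, meta_input_size)

# fp16 weights: 16-bit elements -> int16 meta, one word per 16 input columns.
print(meta_shape(4096, 4096, torch.float16))  # (torch.int16, (4096, 256))
# int8 weights: 8-bit elements -> int32 meta, one word per 32 input columns.
print(meta_shape(4096, 4096, torch.int8))     # (torch.int32, (4096, 128))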
@@ -74,21 +87,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # replace by a better way to identify targeted params
         # using COMPRESSION_PARAMS defined by sparse compressors
         # and decompress the weights accordingly

         if hasattr(layer, "weight_packed"):
             # TODO: this name will also be changed to sparse24_packed
-            weight = layer.weight_packed.data
+            weight_packed_data = layer.weight_packed.data
             meta = layer.meta.data

             weight_data = {
-                "weight_packed": weight,
+                "weight_packed": weight_packed_data,
                 "meta": meta
             }

-            decompressed_weight = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
-            decompressed_weight = decompressed_weight
-            compressed = compress_to_torch_sparse_semi_structured_mat(decompressed_weight)
-            layer.weight_packed = Parameter(compressed, requires_grad=False)
+            # decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data).contiguous()
+            # Temporarily swap in Alex's method instead of the line above;
+            # it seems like our compression might be wrong.
+            decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta)
+            compressed = compress_to_torch_sparse_semi_structured_mat(decompress)
+            layer.weight = Parameter(compressed, requires_grad=False)
+        else:
+            # Assume the uncompressed case.
+            # Proof that Alex's methods work: uncommenting the two lines below
+            # and passing decompress into compress_to_torch_sparse_semi_structured_mat
+            # also gives accurate generation.
+            # comp, meta = sparse_semi_structured_from_dense_cutlass(layer.weight)
+            # decompress = sparse_semi_structured_to_dense_cutlass(comp, meta)
+            compressed = compress_to_torch_sparse_semi_structured_mat(layer.weight)
+            layer.weight = Parameter(compressed, requires_grad=False)

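The "might be wrong" note above suggests an obvious cross-check. A hedged debugging sketch that compares the two decompression paths on the same packed data (the wiring is illustrative; the calls themselves are the ones used above):

# Sketch: cross-check the compressed-tensors decompressor against Alex's
# CUTLASS helper on the same packed weight and metadata.
import torch
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    sparse_semi_structured_to_dense_cutlass)

def check_decompressors(model_compressor, weight_packed_data, meta):
    weight_data = {"weight_packed": weight_packed_data, "meta": meta}
    ours = model_compressor.sparsity_compressor.decompress_weight(weight_data)
    alexs = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta)
    # If these disagree, our decompression (or its meta layout) is wrong.
    return torch.allclose(ours.contiguous(), alexs)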
@@ -105,20 +126,16 @@ def apply_weights(self,
         :param bias: The bias to be added to the output tensor
         :return: The output tensor of the layer
         """
-        result = semi_structured_dense_sparse_T_gemm(
-            a_dense=x,
-            b_T_packed=layer.weight_packed.data,
-            bias=bias,
-        )
-
-        has_nans = torch.any(torch.isnan(result))
-
-        assert not has_nans
-
-        print("Result: ", result)
-        print("+" * 10)
-        return result
+        """ debugging code
+        a_sparse = to_sparse_semi_structured(layer.weight)
+        result = torch.mm(a_sparse, x.t().contiguous())
+        return result.t().contiguous()
+        """
+        return semi_structured_dense_sparse_T_gemm(
+            a_dense=x,
+            b_T_packed=layer.weight.data
+        )
@@ -102,7 +102,7 @@ def semi_structured_dense_sparse_T_gemm(a_dense: torch.Tensor,
     Returns:
         torch.Tensor - Result of matrix multiplication.
     '''
-    return (semi_structured_sparse_dense_gemm(b_T_packed, a_dense.t(), bias)).t()
+    return (semi_structured_sparse_dense_gemm(b_T_packed, a_dense.t().contiguous(), bias)).t().contiguous()


 def semi_structured_sparse_dense_gemm_scaled(a_packed: torch.Tensor,

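The .contiguous() calls are the point of this last change: the sparse kernel keeps the packed matrix as the left operand, so the dense-times-sparse-transpose product is phrased through transposes, and .t() produces non-contiguous views that the underlying kernels may reject. A sketch of the identity, using only the public torch API under the same assumptions as the earlier sketches:

# Sketch of the transpose trick behind semi_structured_dense_sparse_T_gemm:
# x @ W^T == (W @ x^T)^T, with the sparse W kept as the left operand.
# .t() returns a non-contiguous view, so .contiguous() re-lays out the data
# before it is handed to the sparse kernel.
import torch
from torch.sparse import to_sparse_semi_structured

W = torch.randn(128, 128, dtype=torch.float16, device="cuda")
W = W * torch.tensor([1, 1, 0, 0], device="cuda").tile(128, 32)  # 2:4 pattern
W_sparse = to_sparse_semi_structured(W)

x = torch.randn(64, 128, dtype=torch.float16, device="cuda")  # batch of inputs
out = torch.mm(W_sparse, x.t().contiguous()).t().contiguous()
assert torch.allclose(out, x @ W.t(), atol=1e-2)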