diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index b786f28e649d8..598dea3d91fd1 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -359,6 +359,7 @@ def get_scheme(
         """
         scheme = CompressedTensors24(
             model_compressor=self.model_compressor,
+            layer_name=layer_name
         )
         return scheme
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 6643396014e89..b96c915dceca8 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -9,14 +9,15 @@
     compress_to_torch_sparse_semi_structured_mat,
     semi_structured_dense_sparse_T_gemm
 )
-
+from torch.sparse import to_sparse_semi_structured
+from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import sparse_semi_structured_to_dense_cutlass, sparse_semi_structured_from_dense_cutlass
 
 __all__ = ["CompressedTensors24"]
 
 
 class CompressedTensors24(CompressedTensorsScheme):
 
-    def __init__(self, model_compressor: Optional[ModelCompressor] = None):
+    def __init__(self, model_compressor: Optional[ModelCompressor] = None, layer_name = None):
         self.model_compressor = model_compressor
-
-
+        self.layer_name = layer_name
+
     @classmethod
     def get_min_capability(cls) -> int:
         return 80
@@ -27,6 +28,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        compressed = True  # toggle based on whether the checkpoint is compressed
         weights_dtype = params_dtype
         weights = ModelWeightParameter(data=torch.empty(
             sum(output_partition_sizes),
@@ -36,30 +38,41 @@
             output_dim=0,
             weight_loader=weight_loader)
 
-        bits_per_weight_element = weights.itemsize * 8
-
-        meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16
-
-
-        meta_input_size = (
-            input_size_per_partition // 32
-            if bits_per_weight_element == 8
-            else input_size_per_partition // 16
-        )
-        meta = ModelWeightParameter(data=torch.empty(
-            sum(output_partition_sizes),
-            meta_input_size,
-            dtype=meta_dtype),
+        # Parameter to store the uncompressed or decompressed weight
+        weight_unpacked = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=weights_dtype),
             input_dim=1,
             output_dim=0,
             weight_loader=weight_loader)
-
-        # TODO: replace weight_packed name, with something
-        # more meaningful, like sparse24_packed, this will
-        # require changes on compressed_tensors side
-
-        layer.register_parameter("weight_packed", weights)
-        layer.register_parameter("meta", meta)
+
+        # For the compressed case
+        if compressed:
+            bits_per_weight_element = weights.itemsize * 8
+            meta_dtype = torch.int32 if bits_per_weight_element == 8 else torch.int16
+
+            meta_input_size = (
+                input_size_per_partition // 32
+                if bits_per_weight_element == 8
+                else input_size_per_partition // 16
+            )
+            meta = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                meta_input_size,
+                dtype=meta_dtype),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader)
+
+            # TODO: replace the weight_packed name with something
+            # more meaningful, like sparse24_packed; this will
+            # require changes on the compressed_tensors side
+
+            layer.register_parameter("weight_packed", weights)
+            layer.register_parameter("meta", meta)
+
+        layer.register_parameter("weight", weight_unpacked)
 
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -74,21 +87,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # replace by a better way to identify targetted params
         # using COMPRESSION_PARAMS defined by sparse compressors
         # and decompress the weights accordingly
-
         if hasattr(layer, "weight_packed"):
             # TODO: this name will also be changed to sparse24_packed
-            weight = layer.weight_packed.data
+            weight_packed_data = layer.weight_packed.data
             meta = layer.meta.data
             weight_data = {
-                "weight_packed": weight,
+                "weight_packed": weight_packed_data,
                 "meta": meta
             }
-
-            decompressed_weight = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
-            decompressed_weight = decompressed_weight
-            compressed = compress_to_torch_sparse_semi_structured_mat(decompressed_weight)
-            layer.weight_packed = Parameter(compressed, requires_grad=False)
+            #decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data).contiguous()
+            # Temporarily swap in to use Alex's method. Seems like the compression might be wrong?
+            decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta)
+            compressed = compress_to_torch_sparse_semi_structured_mat(decompress)
+            layer.weight = Parameter(compressed, requires_grad=False)
+
+        else:
+            # Assume the uncompressed case.
+            # Proof that Alex's methods work: we can compress and decompress to get accurate generation using the methods below.
+            # Equivalent to uncommenting the next two lines and passing decompress into compress_to_torch_sparse_semi_structured_mat, which also works.
+            #comp, meta = sparse_semi_structured_from_dense_cutlass(layer.weight)
+            #decompress = sparse_semi_structured_to_dense_cutlass(comp, meta)
+            compressed = compress_to_torch_sparse_semi_structured_mat(layer.weight)
+            layer.weight = Parameter(compressed, requires_grad=False)
 
 
     def apply_weights(self,
                       layer: torch.nn.Module,
@@ -105,20 +126,16 @@ def apply_weights(self,
         :param bias: The bias to be added to the output tensor
         :return: The output tensor of the layer
         """
-        result = semi_structured_dense_sparse_T_gemm(
-            a_dense=x,
-            b_T_packed=layer.weight_packed.data,
-            bias=bias,
-        )
-
-        has_nans = torch.any(torch.isnan(result))
-
-        assert not has_nans
-        print("Result: ", result)
-        print("+" * 10)
-        return result
-
+        """ debugging code
+        a_sparse = to_sparse_semi_structured(layer.weight)
+        result = torch.mm(a_sparse, x.t().contiguous())
+        return result.t().contiguous()
+        """
+        return semi_structured_dense_sparse_T_gemm(
+            a_dense=x,
+            b_T_packed=layer.weight.data
+        )
\ No newline at end of file
diff --git a/vllm/model_executor/layers/sparsity/utils/cusparse_2_4_utils.py b/vllm/model_executor/layers/sparsity/utils/cusparse_2_4_utils.py
index 4afe325f4e3a9..1be817617e7dc 100644
--- a/vllm/model_executor/layers/sparsity/utils/cusparse_2_4_utils.py
+++ b/vllm/model_executor/layers/sparsity/utils/cusparse_2_4_utils.py
@@ -102,7 +102,7 @@ def semi_structured_dense_sparse_T_gemm(a_dense: torch.Tensor,
     Returns:
         torch.Tensor - Result of matrix multiplication.
     '''
-    return (semi_structured_sparse_dense_gemm(b_T_packed, a_dense.t(), bias)).t()
+    return (semi_structured_sparse_dense_gemm(b_T_packed, a_dense.t().contiguous(), bias)).t().contiguous()
 
 
 def semi_structured_sparse_dense_gemm_scaled(a_packed: torch.Tensor,
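
Note (not part of the patch): the last two hunks lean on PyTorch's 2:4 semi-structured sparsity support, where an fp16 weight with two zeros in every group of four values is packed into values plus metadata and multiplied from the left, so x @ W.T is computed as (W_packed @ x.T).T. Below is a minimal standalone sketch of that pattern using only the public torch.sparse API rather than the vLLM helpers; the shapes, tolerances, and the assumption of an Ampere-or-newer GPU are illustrative, not taken from the patch.

import torch
from torch.sparse import to_sparse_semi_structured

# Build a weight with a valid 2:4 pattern: two zeros in every group of
# four values along a row (assumed shapes; requires a CUDA device with
# compute capability >= 8.0, matching get_min_capability() == 80).
W = torch.randn(128, 128, dtype=torch.float16, device="cuda")
W *= torch.tensor([1, 1, 0, 0], dtype=torch.float16, device="cuda").tile(128, 32)

W_packed = to_sparse_semi_structured(W)   # packed values + sparsity metadata
x = torch.randn(64, 128, dtype=torch.float16, device="cuda")

# x @ W.T expressed as (W_packed @ x.T).T, the same transposed-GEMM trick
# used by semi_structured_dense_sparse_T_gemm and the debugging block above.
out = torch.mm(W_packed, x.t().contiguous()).t().contiguous()

torch.testing.assert_close(out, x @ W.t(), rtol=1e-2, atol=1e-2)

The .contiguous() calls mirror the cusparse_2_4_utils.py fix in the final hunk: the transposed views fed into and returned from the sparse GEMM are non-contiguous, which the underlying sparse kernels appear not to accept.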