diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index b96c915dceca8..2844018cc98cb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -92,14 +92,40 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_packed_data = layer.weight_packed.data meta = layer.meta.data - weight_data = { - "weight_packed": weight_packed_data, - "meta": meta - } - #decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data).contiguous() - # Temporarily swap in to use Alex's method. Seems like the compression might be wrong? - decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta) - compressed = compress_to_torch_sparse_semi_structured_mat(decompress) + qkv_sizes = [2048, 256, 256] + gate_up_sizes = [5632, 5632] + split_weights = None + split_meta = None + + def _process_split(input_weight, input_meta): + weight_data = { + "weight_packed": input_weight, + "meta": input_meta + } + decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data) + return decompress + + print(self.layer_name) + if "qkv" in self.layer_name: + split_weights = torch.split(weight_packed_data, qkv_sizes) + split_meta = torch.split(meta, qkv_sizes) + elif "gate_up" in self.layer_name: + split_weights = torch.split(weight_packed_data, gate_up_sizes) + split_meta = torch.split(meta, gate_up_sizes) + + if split_weights: + all_compress = [] + for i in range(len(split_weights)): + print(split_weights[i].shape, split_meta[i].shape) + compress_i = _process_split(split_weights[i], split_meta[i]) + all_compress.append(compress_i) + + compressed = torch.cat(all_compress) + compressed = compress_to_torch_sparse_semi_structured_mat(compressed) + else: + decompress = _process_split(weight_packed_data, meta) + compressed = compress_to_torch_sparse_semi_structured_mat(decompress) + layer.weight = Parameter(compressed, requires_grad=False) else: