Skip to content

Commit

Permalink
patch
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Nov 4, 2024
1 parent 0702869 commit 49bdcf5
Showing 1 changed file with 32 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,39 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight_packed_data = layer.weight_packed.data
meta = layer.meta.data

weight_data = {
"weight_packed": weight_packed_data,
"meta": meta
}
#decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data).contiguous()
qkv_sizes = [2048, 256, 256]
gate_up_sizes = [5632, 5632]
split_weights = None
split_meta = None

def _process_split(input_weight, input_meta):
decompress = sparse_semi_structured_to_dense_cutlass(input_weight, input_meta)
return decompress

print(self.layer_name)
if "qkv" in self.layer_name:
split_weights = torch.split(weight_packed_data, qkv_sizes)
split_meta = torch.split(meta, qkv_sizes)
elif "gate_up" in self.layer_name:
split_weights = torch.split(weight_packed_data, gate_up_sizes)
split_meta = torch.split(meta, gate_up_sizes)

if split_weights:
all_compress = []
for i in range(len(split_weights)):
print(split_weights[i].shape, split_meta[i].shape)
compress_i = _process_split(split_weights[i], split_meta[i])
all_compress.append(compress_i)

compressed = torch.cat(all_compress)
compressed = compress_to_torch_sparse_semi_structured_mat(compressed)
else:
decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta)
compressed = compress_to_torch_sparse_semi_structured_mat(decompress)

#decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
# Temporarily swap in to use Alex's method. Seems like the compression might be wrong?
decompress = sparse_semi_structured_to_dense_cutlass(weight_packed_data, meta)
compressed = compress_to_torch_sparse_semi_structured_mat(decompress)

layer.weight = Parameter(compressed, requires_grad=False)

else:
Expand Down

0 comments on commit 49bdcf5

Please sign in to comment.