diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py
index d30b7e28..3f0c0d1c 100644
--- a/llmc/compression/quantization/awq.py
+++ b/llmc/compression/quantization/awq.py
@@ -195,13 +195,15 @@ def block_transform(self, block, input_feat, block_kwargs):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+        inspect_module = subset['inspect']
+
         if not check_do_quant(
             self.block_idx,
             list(layers_dict.keys())[0],
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 8b7b2ad6..5178de8c 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -465,10 +465,8 @@ def block_transform(self, block, input_feat, block_kwargs):

         for index, subset in enumerate(subsets):
             logger.info(f'subset: {subset}')
-            prev_op = subset['prev_op']
             layers_dict = subset['layers']
             input_name = subset['input'][0]
-            inspect_module = subset['inspect']
             inspect_has_kwargs = subset['has_kwargs']
             if inspect_has_kwargs:
                 if 'sub_keys' in subset:
@@ -478,11 +476,8 @@
             else:
                 subset_kwargs = {}
             self.subset_transform(
-                layers_dict,
+                subset,
                 input_feat,
-                prev_op,
-                input_name,
-                inspect_module,
                 subset_kwargs,
             )
             if self.act_static:
diff --git a/llmc/compression/quantization/dgq.py b/llmc/compression/quantization/dgq.py
index 9342a591..a39036f6 100644
--- a/llmc/compression/quantization/dgq.py
+++ b/llmc/compression/quantization/dgq.py
@@ -276,13 +276,14 @@ def search_scale_zero_subset(self, layers_dict, input_feat):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+
         layers = list(layers_dict.values())
         if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
             self.smoothquant_transform(prev_op, layers, input_feat[input_name])
diff --git a/llmc/compression/quantization/gptq.py b/llmc/compression/quantization/gptq.py
index 8f5a9d5d..165b7f7a 100644
--- a/llmc/compression/quantization/gptq.py
+++ b/llmc/compression/quantization/gptq.py
@@ -89,7 +89,13 @@ def block_transform(self, block, input_feat, block_kwargs):
         super().block_transform(block, input_feat, block_kwargs)

     @torch.no_grad()
-    def subset_transform(self, layers_dict, *args, **kwargs):
+    def subset_transform(
+        self,
+        subset,
+        input_feat,
+        subset_kwargs,
+    ):
+        layers_dict = subset['layers']
         for name in layers_dict:
             layer = layers_dict[name]
             self.layer_transform(layer, name)
diff --git a/llmc/compression/quantization/osplus.py b/llmc/compression/quantization/osplus.py
index c6003faf..c8696c13 100644
--- a/llmc/compression/quantization/osplus.py
+++ b/llmc/compression/quantization/osplus.py
@@ -171,13 +171,15 @@ def search_scale_shift_subset(
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+        inspect_module = subset['inspect']
+
         assert (
             len(prev_op) == 1
         ), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
diff --git a/llmc/compression/quantization/rtn.py b/llmc/compression/quantization/rtn.py
index 966524c8..d5777097 100644
--- a/llmc/compression/quantization/rtn.py
+++ b/llmc/compression/quantization/rtn.py
@@ -19,11 +19,8 @@ def block_opt(self, *opt_kwargs):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
         pass
diff --git a/llmc/compression/quantization/smoothquant.py b/llmc/compression/quantization/smoothquant.py
index 3c47b4bc..d238f4e5 100644
--- a/llmc/compression/quantization/smoothquant.py
+++ b/llmc/compression/quantization/smoothquant.py
@@ -61,13 +61,14 @@ def search_scale_subset(self, layers, tensors):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
         subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        prev_op = subset['prev_op']
+        input_name = subset['input'][0]
+
         if not self.filter_subset(prev_op):
             logger.info('Do not transform this subset.')
             return
diff --git a/llmc/compression/sparsification/magnitude.py b/llmc/compression/sparsification/magnitude.py
index af3712c4..9057ec1a 100644
--- a/llmc/compression/sparsification/magnitude.py
+++ b/llmc/compression/sparsification/magnitude.py
@@ -14,13 +14,12 @@ def __init__(self, model, sparsity_config, input, padding_mask, config, modality
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
-        subset_kwargs
+        subset_kwargs,
     ):
+        layers_dict = subset['layers']
+
         layers = list(layers_dict.values())
         for layer in layers:
             W = layer.weight.data
diff --git a/llmc/compression/sparsification/wanda.py b/llmc/compression/sparsification/wanda.py
index c0cfe710..01c88720 100644
--- a/llmc/compression/sparsification/wanda.py
+++ b/llmc/compression/sparsification/wanda.py
@@ -33,13 +33,13 @@ def get_row_scale(self, layer, act):
     @torch.no_grad()
     def subset_transform(
         self,
-        layers_dict,
+        subset,
         input_feat,
-        prev_op,
-        input_name,
-        inspect_module,
-        subset_kwargs
+        subset_kwargs,
     ):
+        layers_dict = subset['layers']
+        input_name = subset['input'][0]
+
         layers = list(layers_dict.values())
         for layer in layers:
             scaler_row = self.get_row_scale(layer, input_feat[input_name][0])
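
Taken together, the hunks above collapse the old five-argument signature `subset_transform(layers_dict, input_feat, prev_op, input_name, inspect_module, subset_kwargs)` into `subset_transform(subset, input_feat, subset_kwargs)`: the caller in base_blockwise_quantization.py passes the whole `subset` dict through, and each subclass (AWQ, GPTQ, DGQ, OS+, SmoothQuant, RTN, Magnitude, Wanda) unpacks only the fields it actually uses. The sketch below illustrates the shape of `subset` implied by the diff; only the keys ('layers', 'prev_op', 'input', 'inspect', 'has_kwargs', 'sub_keys') appear in the hunks, and the placeholder modules are invented for illustration.

import torch.nn as nn

# Hypothetical stand-ins for real transformer submodules.
q_proj = nn.Linear(16, 16)
input_layernorm = nn.LayerNorm(16)
self_attn = nn.Identity()

# Assumed shape of one `subset` entry, based on the keys read in the diff.
subset = {
    'layers': {'self_attn.q_proj': q_proj},  # name -> module to transform
    'prev_op': [input_layernorm],            # op(s) feeding this subset
    'input': ['self_attn.q_proj'],           # input_feat key; code reads [0]
    'inspect': self_attn,                    # module inspected for outputs
    'has_kwargs': False,                     # whether `inspect` needs block kwargs
}

# New uniform call convention shared by every implementation:
# self.subset_transform(subset, input_feat, subset_kwargs)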