Skip to content

Commit

Permalink
update subset_transform (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang authored Nov 21, 2024
1 parent 5197a64 commit 135fe9c
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 37 deletions.
10 changes: 6 additions & 4 deletions llmc/compression/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,15 @@ def block_transform(self, block, input_feat, block_kwargs):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]
inspect_module = subset['inspect']

if not check_do_quant(
self.block_idx,
list(layers_dict.keys())[0],
Expand Down
7 changes: 1 addition & 6 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,8 @@ def block_transform(self, block, input_feat, block_kwargs):

for index, subset in enumerate(subsets):
logger.info(f'subset: {subset}')
prev_op = subset['prev_op']
layers_dict = subset['layers']
input_name = subset['input'][0]
inspect_module = subset['inspect']
inspect_has_kwargs = subset['has_kwargs']
if inspect_has_kwargs:
if 'sub_keys' in subset:
Expand All @@ -478,11 +476,8 @@ def block_transform(self, block, input_feat, block_kwargs):
else:
subset_kwargs = {}
self.subset_transform(
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
)
if self.act_static:
Expand Down
9 changes: 5 additions & 4 deletions llmc/compression/quantization/dgq.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,14 @@ def search_scale_zero_subset(self, layers_dict, input_feat):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]

layers = list(layers_dict.values())
if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
self.smoothquant_transform(prev_op, layers, input_feat[input_name])
Expand Down
8 changes: 7 additions & 1 deletion llmc/compression/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ def block_transform(self, block, input_feat, block_kwargs):
super().block_transform(block, input_feat, block_kwargs)

@torch.no_grad()
def subset_transform(self, layers_dict, *args, **kwargs):
def subset_transform(
self,
subset,
input_feat,
subset_kwargs,
):
layers_dict = subset['layers']
for name in layers_dict:
layer = layers_dict[name]
self.layer_transform(layer, name)
Expand Down
10 changes: 6 additions & 4 deletions llmc/compression/quantization/osplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ def search_scale_shift_subset(
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]
inspect_module = subset['inspect']

assert (
len(prev_op) == 1
), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
Expand Down
5 changes: 1 addition & 4 deletions llmc/compression/quantization/rtn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ def block_opt(self, *opt_kwargs):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
pass
9 changes: 5 additions & 4 deletions llmc/compression/quantization/smoothquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ def search_scale_subset(self, layers, tensors):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]

if not self.filter_subset(prev_op):
logger.info('Do not transform this subset.')
return
Expand Down
9 changes: 4 additions & 5 deletions llmc/compression/sparsification/magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ def __init__(self, model, sparsity_config, input, padding_mask, config, modality
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs
subset_kwargs,
):
layers_dict = subset['layers']

layers = list(layers_dict.values())
for layer in layers:
W = layer.weight.data
Expand Down
10 changes: 5 additions & 5 deletions llmc/compression/sparsification/wanda.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ def get_row_scale(self, layer, act):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs
subset_kwargs,
):
layers_dict = subset['layers']
input_name = subset['input'][0]

layers = list(layers_dict.values())
for layer in layers:
scaler_row = self.get_row_scale(layer, input_feat[input_name][0])
Expand Down

0 comments on commit 135fe9c

Please sign in to comment.