Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update subset_transform #218

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions llmc/compression/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,15 @@ def block_transform(self, block, input_feat, block_kwargs):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]
inspect_module = subset['inspect']

if not check_do_quant(
self.block_idx,
list(layers_dict.keys())[0],
Expand Down
7 changes: 1 addition & 6 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,10 +465,8 @@ def block_transform(self, block, input_feat, block_kwargs):

for index, subset in enumerate(subsets):
logger.info(f'subset: {subset}')
prev_op = subset['prev_op']
layers_dict = subset['layers']
input_name = subset['input'][0]
inspect_module = subset['inspect']
inspect_has_kwargs = subset['has_kwargs']
if inspect_has_kwargs:
if 'sub_keys' in subset:
Expand All @@ -478,11 +476,8 @@ def block_transform(self, block, input_feat, block_kwargs):
else:
subset_kwargs = {}
self.subset_transform(
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
)
if self.act_static:
Expand Down
9 changes: 5 additions & 4 deletions llmc/compression/quantization/dgq.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,14 @@ def search_scale_zero_subset(self, layers_dict, input_feat):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]

layers = list(layers_dict.values())
if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
self.smoothquant_transform(prev_op, layers, input_feat[input_name])
Expand Down
8 changes: 7 additions & 1 deletion llmc/compression/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ def block_transform(self, block, input_feat, block_kwargs):
super().block_transform(block, input_feat, block_kwargs)

@torch.no_grad()
def subset_transform(self, layers_dict, *args, **kwargs):
def subset_transform(
self,
subset,
input_feat,
subset_kwargs,
):
layers_dict = subset['layers']
for name in layers_dict:
layer = layers_dict[name]
self.layer_transform(layer, name)
Expand Down
10 changes: 6 additions & 4 deletions llmc/compression/quantization/osplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ def search_scale_shift_subset(
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]
inspect_module = subset['inspect']

assert (
len(prev_op) == 1
), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
Expand Down
5 changes: 1 addition & 4 deletions llmc/compression/quantization/rtn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ def block_opt(self, *opt_kwargs):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
pass
9 changes: 5 additions & 4 deletions llmc/compression/quantization/smoothquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ def search_scale_subset(self, layers, tensors):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
):
layers_dict = subset['layers']
prev_op = subset['prev_op']
input_name = subset['input'][0]

if not self.filter_subset(prev_op):
logger.info('Do not transform this subset.')
return
Expand Down
9 changes: 4 additions & 5 deletions llmc/compression/sparsification/magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ def __init__(self, model, sparsity_config, input, padding_mask, config, modality
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs
subset_kwargs,
):
layers_dict = subset['layers']

layers = list(layers_dict.values())
for layer in layers:
W = layer.weight.data
Expand Down
10 changes: 5 additions & 5 deletions llmc/compression/sparsification/wanda.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ def get_row_scale(self, layer, act):
@torch.no_grad()
def subset_transform(
self,
layers_dict,
subset,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs
subset_kwargs,
):
layers_dict = subset['layers']
input_name = subset['input'][0]

layers = list(layers_dict.values())
for layer in layers:
scaler_row = self.get_row_scale(layer, input_feat[input_name][0])
Expand Down
Loading