diff --git a/configs/sparsification/ShortGPT/shortgpt.yml b/configs/sparsification/ShortGPT/shortgpt.yml
new file mode 100644
index 00000000..f651e92a
--- /dev/null
+++ b/configs/sparsification/ShortGPT/shortgpt.yml
@@ -0,0 +1,30 @@
+base:
+    seed: &seed 42
+model:
+    type: Llama
+    path: model path
+    torch_dtype: auto
+calib:
+    name: pileval
+    download: False
+    path: calib data path
+    n_samples: 128
+    bs: -1
+    seq_len: 512
+    preproc: general
+    seed: *seed
+eval:
+    eval_pos: [transformed]
+    name: [wikitext2, c4]
+    download: False
+    path: eval data path
+    seq_len: 2048
+sparse:
+    method: ShortGPT
+    weight:
+        n_prune_layers: 9
+save:
+    save_trans: True
+    save_fp: False
+    save_lightllm: False
+    save_path: ./save
diff --git a/docs/en/source/advanced/sparsification.md b/docs/en/source/advanced/sparsification.md
index c3b7365a..88b26f08 100644
--- a/docs/en/source/advanced/sparsification.md
+++ b/docs/en/source/advanced/sparsification.md
@@ -1,6 +1,6 @@
 # Model Sparsification
 
-The llmc is currently gradually supporting sparse methods, having already implemented Magnitude and Wanda, and will support more algorithms in the future.
+llmc is gradually adding support for sparsification methods, having already implemented Magnitude, Wanda, and ShortGPT, and will support more algorithms in the future.
 
 Here is a sample of Wanda's settings:
 
diff --git a/docs/zh_cn/source/advanced/sparsification.md b/docs/zh_cn/source/advanced/sparsification.md
index e15f5664..3e82c266 100644
--- a/docs/zh_cn/source/advanced/sparsification.md
+++ b/docs/zh_cn/source/advanced/sparsification.md
@@ -1,6 +1,6 @@
 # 模型稀疏化
 
-llmc目前正在逐渐支持稀疏化方法,目前已经实现了Magnitude和Wanda,将在未来支持更多的算法。
+llmc目前正在逐渐支持稀疏化方法,目前已经实现了Magnitude、Wanda和ShortGPT,将在未来支持更多的算法。
 
 以下是Wanda的设置样例:
 
diff --git a/llmc/__main__.py b/llmc/__main__.py
index a316e315..5ae500c8 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -47,7 +47,6 @@ def main(config):
         for ppl_eval in eval_list:
             ppl = ppl_eval.eval(model)
             logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
-    sparsification = None
     if not config.get('calib', False):
         blockwise_opt = ALGO_REGISTRY[config.quant.method](
             model, quant_config=config.quant, input=None, config=config
@@ -61,20 +60,17 @@
         gc.collect()
         torch.cuda.empty_cache()
         if not config.get('sparse', False):
-            sparsification = False
             blockwise_opt = ALGO_REGISTRY[config.quant.method](
                 model, config.quant, model.get_first_block_input(), config
             )
         else:
-            sparsification = True
             blockwise_opt = ALGO_REGISTRY[config.sparse.method](
                 model, config.sparse, model.get_first_block_input(), config
             )
         blockwise_opt.run_block_loop()
 
         if 'eval' in config and 'transformed' in config.eval.eval_pos:
-            if not sparsification:
-                blockwise_opt.deploy('origin_float')
+            blockwise_opt.deploy('origin_float')
             for ppl_eval in eval_list:
                 ppl = ppl_eval.eval(model)
                 logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
diff --git a/llmc/compression/sparsification/__init__.py b/llmc/compression/sparsification/__init__.py
index ead12f1e..b09a1347 100644
--- a/llmc/compression/sparsification/__init__.py
+++ b/llmc/compression/sparsification/__init__.py
@@ -1,4 +1,5 @@
 from .base_blockwise_sparsification import BaseBlockwiseSparsification
 from .magnitude import Magnitude
+from .shortgpt import ShortGPT
 from .sparse import Sparser
 from .wanda import Wanda
diff --git a/llmc/compression/sparsification/base_blockwise_sparsification.py b/llmc/compression/sparsification/base_blockwise_sparsification.py
index 0dfe78cc..607e244f 100644
--- a/llmc/compression/sparsification/base_blockwise_sparsification.py
+++ b/llmc/compression/sparsification/base_blockwise_sparsification.py
@@ -5,6 +5,8 @@
 import torch
 from loguru import logger
 
+from llmc.utils import copy_files
+
 from ..blockwise_optimization import BlockwiseOpt
 from .sparse import Sparser
 
@@ -18,15 +20,15 @@ def block_init(self, block):
         pass
 
     def set_sparsity_config(self):
-        if (
-            'sparsity_out' in self.sparsity_config
-            and self.sparsity_config['sparsity_out']
-        ):
+        if 'sparsity_out' in self.sparsity_config and self.sparsity_config[
+            'sparsity_out'
+        ]:
             self.sparsity_out = True
         else:
             self.sparsity_out = False
         logger.info(f'use sparsity_out {self.sparsity_out}')
-        self.sparser = Sparser(**self.sparsity_config['weight'])
+
+        self.sparser = Sparser(self.sparsity_config['weight'])
 
     def block_forward(self, block, input_data=None):
         output = []
@@ -35,10 +37,9 @@
         for i in range(len(input_data)):
             input_data[i] = input_data[i].to(device=next(block.parameters()).device)
 
-            if (
-                'attention_mask' in self.input['kwargs'][i]
-                and self.input['kwargs'][i]['attention_mask'] is not None
-            ):
+            if 'attention_mask' in self.input[
+                'kwargs'
+            ][i] and self.input['kwargs'][i]['attention_mask'] is not None:
                 self.input['kwargs'][i]['attention_mask'] = self.input['kwargs'][i][
                     'attention_mask'
                 ].cuda()
@@ -47,10 +48,10 @@
                 out = block(input_data[i], **self.input['kwargs'][i])[0]
                 output.append(out)
         return output
-    def block_opt(self, block, idx):
+    def block_opt(self, block):
         block = block.cuda()
         named_linears = self.model.get_block_linears(block)
-        # logger.info(f"named_linears: {named_linears}")
+        logger.info(f'named_linears: {named_linears}')
         input_feat = defaultdict(list)
         handles = []
         self.block_init(block)
@@ -72,7 +73,7 @@
             h.remove()
         torch.cuda.empty_cache()
 
-        self.block_transform(block, input_feat, idx, self.input['kwargs'])
+        self.block_transform(block, input_feat, self.input['kwargs'])
 
         if self.sparsity_out:
             self.input['data'] = self.block_forward(block)
@@ -82,8 +83,8 @@
         gc.collect()
         torch.cuda.empty_cache()
 
-    def block_transform(self, block, input_feat, idx, block_kwargs):
-        logger.info(f'Start transform the {idx+1}-th block')
+    def block_transform(self, block, input_feat, block_kwargs):
+        logger.info(f'Start transform the {self.block_idx+1}-th block')
         subsets = self.model.get_subsets_in_block(block)
         for index, subset in enumerate(subsets):
             if not self.filter_subset(subset):
@@ -101,19 +102,35 @@
                 prev_op,
                 input_name,
                 inspect_module,
-                subset_kwargs,
-                idx,
+                subset_kwargs
             )
 
-        logger.info(f'End transform the {idx+1}-th block')
+        logger.info(f'End transform the {self.block_idx+1}-th block')
 
     def filter_subset(self, subset):
         return True
 
-    # todo
     @torch.no_grad()
-    def deploy(self):
+    def deploy(self, deploy_format):
         logger.info('-- deploy_sparsity_model start --')
         logger.info(f'sparsity_config : {self.sparsity_config}')
-        # self.model.replace_module_all(module, params_dict)
         logger.info('-- deploy_sparsity_model done --')
+
+    @torch.no_grad()
+    def copy_tokenizer(self, path):
+        for substring in self.config.save.get('tokenizer_file_substring', ['token']):
+            copy_files(self.config.model.path, path, substring)
+        logger.info('copy tokenizer done --')
+
+    @torch.no_grad()
+    def save_model(self, path):
+        if self.config.model.type == 'Llava':
+            self.model.llava_model.language_model = self.model.get_model()
+            self.model.llava_model.save_pretrained(path)
+            logger.info('save model done --')
+            self.copy_tokenizer(path)
+            copy_files(self.config.model.path, path, 'preprocessor_config')
+        else:
+            self.model.get_model().save_pretrained(path)
+            logger.info('save model done --')
+            self.copy_tokenizer(path)
diff --git a/llmc/compression/sparsification/magnitude.py b/llmc/compression/sparsification/magnitude.py
index d6f285fe..57ad23a8 100644
--- a/llmc/compression/sparsification/magnitude.py
+++ b/llmc/compression/sparsification/magnitude.py
@@ -19,8 +19,7 @@ def subset_transform(
         prev_op,
         input_name,
         inspect_module,
-        subset_kwargs,
-        idx,
+        subset_kwargs
     ):
         layers = list(layers_dict.values())
         for layer in layers:
diff --git a/llmc/compression/sparsification/shortgpt.py b/llmc/compression/sparsification/shortgpt.py
new file mode 100644
index 00000000..64aadd9f
--- /dev/null
+++ b/llmc/compression/sparsification/shortgpt.py
@@ -0,0 +1,114 @@
+import gc
+import json
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from loguru import logger
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.models.mistral.modeling_mistral import MistralRMSNorm
+
+from llmc.utils import copy_files
+from llmc.utils.registry_factory import ALGO_REGISTRY
+
+from .base_blockwise_sparsification import BaseBlockwiseSparsification
+
+
+@ALGO_REGISTRY
+class ShortGPT(BaseBlockwiseSparsification):
+    def __init__(self, model, sparsity_config, input, config):
+        super().__init__(model, sparsity_config, input, config)
+
+    def block_opt(self, block):
+        block = block.cuda()
+
+        output_feat = self.block_forward(block)
+        torch.cuda.empty_cache()
+        self.block_transform(self.input['data'], output_feat)
+        self.input['data'] = output_feat
+
+    def block_transform(self, input_feat, output_feat):
+        logger.info(f'Start transform the {self.block_idx+1}-th block')
+        self.subset_transform(
+            input_feat,
+            output_feat
+        )
+
+    @torch.no_grad()
+    def compute_bi(
+        self,
+        input_feat: torch.Tensor,
+        output_feat: torch.Tensor
+    ):
+        _, _, d = input_feat.shape
+        input_feat = input_feat.reshape(-1, d)
+        output_feat = output_feat.reshape(-1, d)
+
+        norm_input = input_feat.norm(dim=-1, keepdim=True)
+        norm_output = output_feat.norm(dim=-1, keepdim=True)
+
+        sim = (input_feat @ output_feat.T) / (norm_input * norm_output)
+        sim = sim.diagonal().nan_to_num(nan=0.5)
+
+        return 1 - sim
+
+    @torch.no_grad()
+    def subset_transform(
+        self,
+        input_feat,
+        output_feat
+    ):
+        # calculate BI score
+        if self.sparser.importances is None:
+            self.sparser.importances = np.zeros(len(self.blocks))
+        self.sparser.importances[self.block_idx] = self.compute_bi(
+            input_feat[0], output_feat[0]
+        ).sum().cpu().item()
+
+    @torch.no_grad()
+    def remove_layers(
+        self,
+        layers_to_remove: Optional[List[int]] = []
+    ):
+        if not layers_to_remove and self.sparser.n_prune_layers:
+            layers_to_remove = np.argsort(
+                np.array(self.sparser.importances)
+            )[:self.sparser.n_prune_layers].tolist()
+
+        for idx in sorted(layers_to_remove, reverse=True):
+            try:
+                del self.blocks[idx]
+            except IndexError:
+                logger.info(f'layer {idx} does not exist')
+        return layers_to_remove
+
+    @torch.no_grad()
+    def deploy(self, deploy_format):
+        logger.info(f'After compute, BI scores are {self.sparser.importances}')
+        logger.info('-- deploy_sparsity_model start --')
+        logger.info(f'sparsity_config : {self.sparsity_config}')
+        logger.info('-- begin remove layers --')
+        layers_to_remove = self.remove_layers()
+        logger.info(f'remove layers: {layers_to_remove}')
+        logger.info('-- deploy_sparsity_model done --')
+
+    @torch.no_grad()
+    def save_model(self, path):
+        if self.config.model.type == 'Llava':
+            self.model.llava_model.language_model = self.model.get_model()
+            self.model.llava_model.save_pretrained(path)
+            logger.info('save model done --')
+            self.copy_tokenizer(path)
+            copy_files(self.config.model.path, path, 'preprocessor_config')
+        else:
+            self.model.get_model().save_pretrained(path)
+            config_file = path + '/config.json'
+
+            logger.info('save model done --')
+            self.copy_tokenizer(path)
+            with open(config_file, 'r') as file:
+                config_new = json.load(file)
+            config_new['num_hidden_layers'] = len(self.blocks)
+            with open(config_file, 'w') as file:
+                json.dump(config_new, file, indent=4)
diff --git a/llmc/compression/sparsification/sparse.py b/llmc/compression/sparsification/sparse.py
index 8efc5626..09553572 100644
--- a/llmc/compression/sparsification/sparse.py
+++ b/llmc/compression/sparsification/sparse.py
@@ -1,5 +1,9 @@
 class Sparser:
-    def __init__(self, sparsity, **kwargs):
-        self.sparsity = sparsity
+    def __init__(self, sparsity_constraint, **kwargs):
+        if 'sparsity' in sparsity_constraint:
+            self.sparsity = sparsity_constraint['sparsity']
+            self.W_mask = None
+        elif 'n_prune_layers' in sparsity_constraint:
+            self.n_prune_layers = sparsity_constraint['n_prune_layers']
+            self.importances = None
         self.kwargs = kwargs
-        self.W_mask = None
diff --git a/llmc/compression/sparsification/wanda.py b/llmc/compression/sparsification/wanda.py
index 4a36afb9..1cdbc1e7 100644
--- a/llmc/compression/sparsification/wanda.py
+++ b/llmc/compression/sparsification/wanda.py
@@ -38,8 +38,7 @@ def subset_transform(
         prev_op,
         input_name,
         inspect_module,
-        subset_kwargs,
-        idx,
+        subset_kwargs
     ):
         layers = list(layers_dict.values())
         for layer in layers:
diff --git a/scripts/run_shortgpt_llama.sh b/scripts/run_shortgpt_llama.sh
new file mode 100644
index 00000000..f56c090a
--- /dev/null
+++ b/scripts/run_shortgpt_llama.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+gpu_id=0
+export CUDA_VISIBLE_DEVICES=$gpu_id
+
+llmc=llmc_path
+export PYTHONPATH=$llmc:$PYTHONPATH
+
+task_name=llm_quant_exp
+
+nohup \
+python -m llmc --config ../configs/sparsification/ShortGPT/shortgpt.yml \
+> ${task_name}.log 2>&1 &
+
+echo $! > ${task_name}.pid
\ No newline at end of file
diff --git a/scripts/run_wanda_llama.sh b/scripts/run_wanda_llama.sh
new file mode 100644
index 00000000..96b31c51
--- /dev/null
+++ b/scripts/run_wanda_llama.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+gpu_id=0
+export CUDA_VISIBLE_DEVICES=$gpu_id
+
+llmc=llmc_path
+export PYTHONPATH=$llmc:$PYTHONPATH
+
+task_name=llm_quant_exp
+
+nohup \
+python -m llmc --config ../configs/sparsification/Wanda/wanda.yml \
+> ${task_name}.log 2>&1 &
+
+echo $! > ${task_name}.pid
\ No newline at end of file
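
For reference, the Block Influence (BI) score that the new `compute_bi` implements is one minus the cosine similarity between the hidden states entering and leaving a block, summed over token positions; blocks whose output barely differs in direction from their input score low and are pruned first by `remove_layers`. A minimal standalone sketch of the same computation (dummy tensor shapes assumed, not llmc API):

    import torch

    def block_influence(x_in: torch.Tensor, x_out: torch.Tensor) -> torch.Tensor:
        # x_in / x_out: (batch, seq_len, hidden) activations entering/leaving a block.
        d = x_in.shape[-1]
        cos = torch.nn.functional.cosine_similarity(
            x_in.reshape(-1, d), x_out.reshape(-1, d), dim=-1
        )
        # NaNs are mapped to a neutral 0.5, matching compute_bi above.
        return (1 - cos.nan_to_num(nan=0.5)).sum()

    x = torch.randn(2, 16, 64)
    print(block_influence(x, x + 0.01 * torch.randn_like(x)))  # near-identity block: low BI
    print(block_influence(x, torch.randn_like(x)))             # heavily transforming block: high BI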