add ShortGPT
MercuryB1 committed Aug 2, 2024
1 parent 339e151 commit 0cd2db8
Showing 12 changed files with 251 additions and 62 deletions.
30 changes: 30 additions & 0 deletions configs/sparsification/ShortGPT/shortgpt.yml
@@ -0,0 +1,30 @@
base:
    seed: &seed 42
model:
    type: Llama
    path: model path
    torch_dtype: auto
calib:
    name: pileval
    download: False
    path: calib data path
    n_samples: 128
    bs: -1
    seq_len: 512
    preproc: general
    seed: *seed
eval:
    eval_pos: [transformed]
    name: [wikitext2, c4]
    download: False
    path: eval data path
    seq_len: 2048
sparse:
    method: ShortGPT
    weight:
        n_prune_layers: 9
save:
    save_trans: True
    save_fp: False
    save_lightllm: False
    save_path: ./save
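
For orientation, here is a minimal, hedged sketch of reading this new config outside of llmc (it assumes only PyYAML and the file path shown in the header above; the `model path`, `calib data path`, and `eval data path` values are placeholders that must be replaced before a real run):

# Sketch: load shortgpt.yml and inspect the ShortGPT-specific knobs (assumes PyYAML).
import yaml

with open('configs/sparsification/ShortGPT/shortgpt.yml') as f:
    cfg = yaml.safe_load(f)   # the &seed / *seed anchor resolves to 42

print(cfg['sparse']['method'])                    # 'ShortGPT'
print(cfg['sparse']['weight']['n_prune_layers'])  # 9 -> number of blocks to drop
print(cfg['save']['save_path'])                   # './save'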
2 changes: 1 addition & 1 deletion docs/en/source/advanced/sparsification.md
@@ -1,6 +1,6 @@
# Model Sparsification

The llmc is currently gradually supporting sparse methods, having already implemented Magnitude and Wanda, and will support more algorithms in the future.
The llmc is currently gradually supporting sparse methods, having already implemented Magnitude, Wanda, and ShortGPT, and will support more algorithms in the future.

Here is a sample of Wanda's settings:

2 changes: 1 addition & 1 deletion docs/zh_cn/source/advanced/sparsification.md
@@ -1,6 +1,6 @@
# 模型稀疏化

llmc目前正在逐渐支持稀疏化方法,目前已经实现了Magnitude和Wanda,将在未来支持更多的算法
llmc目前正在逐渐支持稀疏化方法,目前已经实现了Magnitude、Wanda和ShortGPT,将在未来支持更多的算法

以下是Wanda的设置样例:

6 changes: 1 addition & 5 deletions llmc/__main__.py
@@ -47,7 +47,6 @@ def main(config):
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
sparsification = None
if not config.get('calib', False):
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model, quant_config=config.quant, input=None, config=config
@@ -61,20 +60,17 @@
gc.collect()
torch.cuda.empty_cache()
if not config.get('sparse', False):
sparsification = False
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model, config.quant, model.get_first_block_input(), config
)
else:
sparsification = True
blockwise_opt = ALGO_REGISTRY[config.sparse.method](
model, config.sparse, model.get_first_block_input(), config
)
blockwise_opt.run_block_loop()

if 'eval' in config and 'transformed' in config.eval.eval_pos:
if not sparsification:
blockwise_opt.deploy('origin_float')
blockwise_opt.deploy('origin_float')
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
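
The edit above drops the `sparsification` bookkeeping flag: whichever algorithm class `ALGO_REGISTRY` returns, the post-transform path now calls `blockwise_opt.deploy('origin_float')` before evaluation. As a purely illustrative sketch of the registry dispatch this relies on (not llmc's actual implementation; the names and dummy arguments are assumptions):

# Illustrative registry pattern: a name from the YAML (`sparse.method`) selects a class.
ALGO_REGISTRY = {}

def register(cls):
    ALGO_REGISTRY[cls.__name__] = cls
    return cls

@register
class ShortGPT:
    def __init__(self, model, sparsity_config, input, config):
        self.model = model
        self.sparsity_config = sparsity_config

algo_cls = ALGO_REGISTRY['ShortGPT']   # config.sparse.method == 'ShortGPT'
blockwise_opt = algo_cls(model=None, sparsity_config={'n_prune_layers': 9},
                         input=None, config=None)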
1 change: 1 addition & 0 deletions llmc/compression/sparsification/__init__.py
@@ -2,3 +2,4 @@
from .magnitude import Magnitude
from .sparse import Sparser
from .wanda import Wanda
from .shortgpt import ShortGPT
111 changes: 64 additions & 47 deletions llmc/compression/sparsification/base_blockwise_sparsification.py
@@ -1,56 +1,54 @@
import functools
import gc
from loguru import logger
from collections import defaultdict

import functools
import torch
from loguru import logger

from ..blockwise_optimization import BlockwiseOpt
import gc
from .sparse import Sparser

from llmc.utils import copy_files
from ..blockwise_optimization import BlockwiseOpt

class BaseBlockwiseSparsification(BlockwiseOpt):
def __init__(self, model, sparsity_config, input, config):
super().__init__(model, sparsity_config, input, config)
self.set_sparsity_config()

def block_init(self, block):
pass
pass

def set_sparsity_config(self):
if (
'sparsity_out' in self.sparsity_config
and self.sparsity_config['sparsity_out']
):
if "sparsity_out" in self.sparsity_config and self.sparsity_config['sparsity_out']:
self.sparsity_out = True
else:
self.sparsity_out = False
logger.info(f'use sparsity_out {self.sparsity_out}')
self.sparser = Sparser(**self.sparsity_config['weight'])
logger.info(f"use sparsity_out {self.sparsity_out}")

self.sparser = Sparser(self.sparsity_config["weight"])




def block_forward(self, block, input_data=None):
output = []
if input_data is None:
input_data = self.input['data']

input_data = self.input["data"]
for i in range(len(input_data)):
input_data[i] = input_data[i].to(device=next(block.parameters()).device)
if (
'attention_mask' in self.input['kwargs'][i]
and self.input['kwargs'][i]['attention_mask'] is not None
):
self.input['kwargs'][i]['attention_mask'] = self.input['kwargs'][i][
'attention_mask'
if "attention_mask" in self.input["kwargs"][i] and self.input["kwargs"][i]["attention_mask"] is not None:
self.input["kwargs"][i]["attention_mask"] = self.input["kwargs"][i][
"attention_mask"
].cuda()
with torch.no_grad():
out = block(input_data[i], **self.input['kwargs'][i])[0]
out = block(input_data[i], **self.input["kwargs"][i])[0]
output.append(out)
return output

def block_opt(self, block, idx):


def block_opt(self, block):
block = block.cuda()
named_linears = self.model.get_block_linears(block)
# logger.info(f"named_linears: {named_linears}")
logger.info(f"named_linears: {named_linears}")
input_feat = defaultdict(list)
handles = []
self.block_init(block)
@@ -65,55 +63,74 @@ def block_opt(self, block, idx):
)

if not self.sparsity_out:
self.input['data'] = self.block_forward(block)
self.input["data"] = self.block_forward(block)
else:
self.block_forward(block)
for h in handles:
h.remove()
torch.cuda.empty_cache()

self.block_transform(block, input_feat, idx, self.input['kwargs'])

self.block_transform(block, input_feat, self.input['kwargs'])
if self.sparsity_out:
self.input['data'] = self.block_forward(block)

block = block.cpu()
del input_feat
gc.collect()
torch.cuda.empty_cache()

def block_transform(self, block, input_feat, idx, block_kwargs):
logger.info(f'Start transform the {idx+1}-th block')


def block_transform(self, block, input_feat, block_kwargs):
logger.info(f"Start transform the {self.block_idx+1}-th block")
subsets = self.model.get_subsets_in_block(block)
for index, subset in enumerate(subsets):
if not self.filter_subset(subset):
continue
# logger.info(f"subset: {subset}")
prev_op = subset['prev_op']
layers_dict = subset['layers']
input_name = subset['input'][0]
inspect_module = subset['inspect']
inspect_has_kwargs = subset['has_kwargs']
prev_op = subset["prev_op"]
layers_dict = subset["layers"]
input_name = subset["input"][0]
inspect_module = subset["inspect"]
inspect_has_kwargs = subset["has_kwargs"]
subset_kwargs = block_kwargs if inspect_has_kwargs else {}
self.subset_transform(
layers_dict,
input_feat,
prev_op,
input_name,
inspect_module,
subset_kwargs,
idx,
subset_kwargs
)
logger.info(f'End transform the {idx+1}-th block')
logger.info(f"End transform the {self.block_idx+1}-th block")


def filter_subset(self, subset):
return True

@torch.no_grad()
def deploy(self, deploy_format):
logger.info(f"-- deploy_sparsity_model start --")
logger.info(f"sparsity_config : {self.sparsity_config}")

logger.info(f"-- deploy_sparsity_model done --")


# todo
@torch.no_grad()
def deploy(self):
logger.info('-- deploy_sparsity_model start --')
logger.info(f'sparsity_config : {self.sparsity_config}')
def copy_tokenizer(self, path):
for substring in self.config.save.get("tokenizer_file_substring", ["token"]):
copy_files(self.config.model.path, path, substring)
logger.info(f"copy tokenizer done --")

# self.model.replace_module_all(module, params_dict)
logger.info('-- deploy_sparsity_model done --')
@torch.no_grad()
def save_model(self, path):
if self.config.model.type == "Llava":
self.model.llava_model.language_model = self.model.get_model()
self.model.llava_model.save_pretrained(path)
logger.info(f"save model done --")
self.copy_tokenizer(path)
copy_files(self.config.model.path, path, "preprocessor_config")
else:
self.model.get_model().save_pretrained(path)
logger.info(f"save model done --")
self.copy_tokenizer(path)
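
The hook-registration code that fills `input_feat` is collapsed in the hunk above, but the pattern `block_opt` relies on is the standard forward-hook capture. A self-contained sketch on a toy block (the toy module and names are assumptions, not llmc's API):

# Sketch of the input-capture pattern: forward hooks record the tensors flowing
# into each nn.Linear of a block so the transform step can use real activations.
import functools
from collections import defaultdict

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))   # stand-in for a transformer block
input_feat = defaultdict(list)

def cache_input_hook(module, inputs, output, name, feat_dict):
    # record the tensor that flowed into this linear layer
    feat_dict[name].append(inputs[0].detach().cpu())

handles = [
    m.register_forward_hook(functools.partial(cache_input_hook, name=n, feat_dict=input_feat))
    for n, m in block.named_modules() if isinstance(m, nn.Linear)
]
with torch.no_grad():
    block(torch.randn(4, 16))
for h in handles:
    h.remove()
print({k: v[0].shape for k, v in input_feat.items()})   # {'0': torch.Size([4, 16]), '1': ...}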
3 changes: 1 addition & 2 deletions llmc/compression/sparsification/magnitude.py
@@ -19,8 +19,7 @@ def subset_transform(
prev_op,
input_name,
inspect_module,
subset_kwargs,
idx,
subset_kwargs
):
layers = list(layers_dict.values())
for layer in layers:
113 changes: 113 additions & 0 deletions llmc/compression/sparsification/shortgpt.py
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
from loguru import logger
import gc
from typing import List, Optional
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from .base_blockwise_sparsification import BaseBlockwiseSparsification
from llmc.utils.registry_factory import ALGO_REGISTRY
import numpy as np
from llmc.utils import copy_files
import json


@ALGO_REGISTRY
class ShortGPT(BaseBlockwiseSparsification):
def __init__(self, model, sparsity_config, input, config):
super().__init__(model, sparsity_config, input, config)

def block_opt(self, block):
block = block.cuda()

output_feat = self.block_forward(block)
torch.cuda.empty_cache()
self.block_transform(self.input["data"], output_feat)
self.input["data"] = output_feat

def block_transform(self, input_feat, output_feat):
logger.info(f"Start transform the {self.block_idx+1}-th block")
self.subset_transform(
input_feat,
output_feat
)

@torch.no_grad()
def compute_bi(
self,
input_feat: torch.Tensor,
output_feat: torch.Tensor
):
_, _, d = input_feat.shape
input_feat = input_feat.reshape(-1, d)
output_feat = output_feat.reshape(-1, d)

norm_input = input_feat.norm(dim=-1, keepdim=True)
norm_output = output_feat.norm(dim=-1, keepdim=True)

sim = (input_feat @ output_feat.T) / (norm_input * norm_output)
sim = sim.diagonal().nan_to_num(nan=0.5)

return 1 - sim
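
In other words, `compute_bi` returns the per-token Block Influence score from the ShortGPT paper: one minus the cosine similarity between a block's input hidden state X_t and its output hidden state Y_t,

$$\mathrm{BI} \;=\; 1 - \mathbb{E}_t\!\left[\frac{X_t^{\top} Y_t}{\lVert X_t \rVert_2 \, \lVert Y_t \rVert_2}\right]$$

`subset_transform` below sums this quantity over the calibration tokens rather than averaging, which preserves the ranking because every block sees the same tokens; `nan_to_num(nan=0.5)` guards against zero-norm rows. Note that the full `input_feat @ output_feat.T` matrix is only needed for its diagonal, so an element-wise `(input_feat * output_feat).sum(-1)` would yield the same scores with far less memory.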

@torch.no_grad()
def subset_transform(
self,
input_feat,
output_feat
):
# calculate BI score
if self.sparser.importances is None:
self.sparser.importances = np.zeros(len(self.blocks))
self.sparser.importances[self.block_idx] = self.compute_bi(input_feat[0], output_feat[0]).sum().cpu().item()


@torch.no_grad()
def remove_layers(
self,
layers_to_remove: Optional[List[int]] = []
):
if not layers_to_remove and self.sparser.n_prune_layers:
layers_to_remove = np.argsort(np.array(self.sparser.importances))[:self.sparser.n_prune_layers].tolist()

for idx in sorted(layers_to_remove, reverse=True):
try:
del self.blocks[idx]
except IndexError:
logger.info(f"layer {idx} does not exist")
return layers_to_remove
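
A worked toy example of the selection above, with made-up importance values (illustrative numbers, not measured scores):

import numpy as np

importances = np.array([5.1, 0.4, 3.2, 0.9, 4.7])     # hypothetical per-block BI sums
n_prune_layers = 2
layers_to_remove = np.argsort(importances)[:n_prune_layers].tolist()
print(layers_to_remove)   # [1, 3]: the blocks whose output changed the least are dropped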


@torch.no_grad()
def deploy(self, deploy_format):
logger.info(f"After compute, BI scores are {self.sparser.importances}")


logger.info(f"-- deploy_sparsity_model start --")
logger.info(f"sparsity_config : {self.sparsity_config}")

logger.info(f"-- begin remove layers --")
layers_to_remove = self.remove_layers()
logger.info(f"remove layers: {layers_to_remove}")

logger.info(f"-- deploy_sparsity_model done --")

@torch.no_grad()
def save_model(self, path):
if self.config.model.type == "Llava":
self.model.llava_model.language_model = self.model.get_model()
self.model.llava_model.save_pretrained(path)
logger.info(f"save model done --")
self.copy_tokenizer(path)
copy_files(self.config.model.path, path, "preprocessor_config")
else:
self.model.get_model().save_pretrained(path)
config_file = path + '/config.json'

logger.info(f"save model done --")
self.copy_tokenizer(path)
with open(config_file, 'r') as file:
config_new = json.load(file)
config_new['num_hidden_layers'] = len(self.blocks)
with open(config_file, 'w') as file:
json.dump(config_new, file, indent=4)
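
Because `save_model` rewrites `num_hidden_layers` in the saved `config.json` to the pruned depth, the checkpoint under `save_path` reloads as an ordinary, shallower model. A hedged sketch (assumes the Hugging Face `transformers` package and the `./save` path from the config above):

from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained('./save')
print(cfg.num_hidden_layers)   # original depth minus n_prune_layers, e.g. 32 - 9 = 23 for a 32-layer Llama
model = AutoModelForCausalLM.from_pretrained('./save', torch_dtype='auto')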
12 changes: 8 additions & 4 deletions llmc/compression/sparsification/sparse.py
@@ -1,5 +1,9 @@
class Sparser:
def __init__(self, sparsity, **kwargs):
self.sparsity = sparsity
self.kwargs = kwargs
self.W_mask = None
def __init__(self, sparsity_constraint, **kwargs):
if 'sparsity' in sparsity_constraint:
self.sparsity = sparsity_constraint["sparsity"]
self.W_mask = None
elif 'n_prune_layers' in sparsity_constraint:
self.n_prune_layers = sparsity_constraint["n_prune_layers"]
self.importances = None
self.kwargs = kwargs
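
With this change the `weight` section of a sparsification config is passed to `Sparser` as a dict, so one class covers both styles of method; roughly (values are illustrative):

# The `weight` dict from the YAML decides which attributes the Sparser carries.
wanda_style = Sparser({'sparsity': 0.5})          # ratio-based pruning: .sparsity, .W_mask
shortgpt_style = Sparser({'n_prune_layers': 9})   # layer pruning: .n_prune_layers, .importances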