Skip to content

Commit

Permalink
add do_trans in config & remove language catcher & support chaglm (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang authored Nov 21, 2024
1 parent 135fe9c commit f4fb949
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 47 deletions.
5 changes: 5 additions & 0 deletions llmc/compression/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ def subset_transform(
prev_op = subset['prev_op']
input_name = subset['input'][0]
inspect_module = subset['inspect']
do_trans = subset.get('do_trans', True)
if not do_trans:
logger.info('do_trans is set to False. Do not transform this subset.')
return

if not check_do_quant(
self.block_idx,
Expand Down Expand Up @@ -241,6 +245,7 @@ def subset_transform(
if (
isinstance(prev_op[0], (nn.Linear, FakeQuantLinear))
and prev_op[0].out_features != layers[0].in_features * 3
and prev_op[0].out_features != layers[0].in_features * 2
and prev_op[0].out_features != layers[0].in_features
):
logger.info('Cannot apply scale. Do not transform this subset.')
Expand Down
13 changes: 12 additions & 1 deletion llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ def apply_shift(self, shifts, prev_op, layers):
def scale_fc_fc(self, fc1, fc2, scales):
scales = scales.to(fc1.weight.device)
if fc1.out_features == fc2.in_features * 3:
logger.info('fc1.out_features == fc2.in_features * 3')
num_heads = self.model.get_num_attention_heads()
fc1.weight.t_()
org_shape = fc1.weight.shape
Expand All @@ -616,13 +617,23 @@ def scale_fc_fc(self, fc1, fc2, scales):
fc1.bias[:, 2, :].shape
)
fc1.bias.data = fc1.bias.data.reshape(-1)
else:
elif fc1.out_features == fc2.in_features * 2:
logger.info('fc1.out_features == fc2.in_features * 2')
fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
if hasattr(fc1, 'bias') and fc1.bias is not None:
fc1.bias.data[fc1.bias.data.shape[0] // 2:].div_(scales.view(-1))
elif fc1.out_features == fc2.in_features:
logger.info('fc1.out_features == fc2.in_features')
assert fc1.out_features == fc2.in_features

if hasattr(fc1, 'bias') and fc1.bias is not None:
fc1.bias.div_(scales.view(-1))

fc1.weight.div_(scales.view(-1, 1))
else:
logger.error(f'fc1.out_features: {fc1.out_features}')
logger.error(f'fc2.in_features: {fc2.in_features}')
raise Exception('Can not scale this fc-fc.')

fc2.weight.mul_(scales.view(1, -1))

Expand Down
1 change: 1 addition & 0 deletions llmc/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bloom import Bloom
from .chatglm import ChatGLM
from .deepseekv2 import DeepseekV2
from .falcon import Falcon
from .gemma2 import Gemma2
Expand Down
26 changes: 2 additions & 24 deletions llmc/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def get_attention_rotary_layers(self):
def batch_process(self):
raise Exception('batch_process should not be called here.')

def get_vision_catcher(self, first_block_input):

def get_catcher(self, first_block_input):
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
Expand All @@ -125,24 +124,6 @@ def forward(self, *args, **kwargs):
kwargs.pop('output_router_logits')
first_block_input['kwargs'].append(kwargs)
raise ValueError

return Catcher

def get_language_catcher(self, first_block_input):

class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

def forward(self, inp, **kwargs):
first_block_input['data'].append(inp)
if 'output_router_logits' in kwargs:
assert kwargs['output_router_logits'] is False
kwargs.pop('output_router_logits')
first_block_input['kwargs'].append(kwargs)
raise ValueError

return Catcher

def __str__(self):
Expand Down Expand Up @@ -184,10 +165,7 @@ def collect_first_block_input(self, calib_data, padding_mask=None,
first_block_input = defaultdict(list)

self.find_blocks(modality)
if modality == 'language':
Catcher = self.get_language_catcher(first_block_input)
elif modality == 'vision':
Catcher = self.get_vision_catcher(first_block_input)
Catcher = self.get_catcher(first_block_input)

self.move_embed_to_device('cuda')
if data_type == 'img_txt':
Expand Down
88 changes: 88 additions & 0 deletions llmc/models/chatglm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import inspect

import torch.nn as nn

from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class ChatGLM(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def find_blocks(self, modality='language'):
self.blocks = self.model.transformer.encoder.layers

def find_embed_layers(self):
self.embedding = self.model.transformer.embedding
self.rotary_pos_emb = self.model.transformer.rotary_pos_emb

def find_block_name(self):
self.block_name_prefix = 'transformer.encoder.layers'

def get_embed_layers(self):
return [self.embedding]

def get_attention_rotary_layers(self):
return [self.rotary_pos_emb]

def get_head_layers(self):
return [self.model.transformer.output_layer]

def get_pre_head_layernorm_layers(self):
return [self.model.transformer.encoder.final_layernorm]

def get_layers_except_blocks(self):
return [self.embedding, self.rotary_pos_emb, self.model.transformer.output_layer, self.model.transformer.encoder.final_layernorm] # noqa

def skip_layer_name(self):
return ['final_layernorm']

def has_bias(self):
return False

def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
}

def get_subsets_in_block(self, block):
return [
{
'layers': {
'self_attention.query_key_value': block.self_attention.query_key_value
},
'prev_op': [block.input_layernorm],
'input': ['self_attention.query_key_value'],
'inspect': block.self_attention,
'has_kwargs': True,
},
{
'layers': {'self_attention.dense': block.self_attention.dense},
'prev_op': [block.self_attention.query_key_value],
'input': ['self_attention.dense'],
'inspect': block.self_attention.dense,
'has_kwargs': False,
},
{
'layers': {
'mlp.dense_h_to_4h': block.mlp.dense_h_to_4h
},
'prev_op': [block.post_attention_layernorm],
'input': ['mlp.dense_h_to_4h'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.down_proj': block.mlp.dense_4h_to_h},
'prev_op': [block.mlp.dense_h_to_4h],
'input': ['mlp.dense_4h_to_h'],
'inspect': block.mlp.dense_4h_to_h,
'has_kwargs': False,
'is_mlp': True,
},
]
1 change: 1 addition & 0 deletions llmc/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,6 @@ def get_subsets_in_block(self, block):
'inspect': block.fc2,
'has_kwargs': False,
'is_mlp': True,
'do_trans': False
},
]
22 changes: 0 additions & 22 deletions llmc/models/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,6 @@ def batch_process(self, imgs):
samples.append(sample)
return samples

def get_catcher(self, first_block_input):

class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
self.signature = inspect.signature(module.forward)

def forward(self, *args, **kwargs):
params = list(self.signature.parameters.keys())
for i, arg in enumerate(args):
if i > 0:
kwargs[params[i]] = arg
first_block_input['data'].append(args[0])
if 'output_router_logits' in kwargs:
assert kwargs['output_router_logits'] is False
kwargs.pop('output_router_logits')
first_block_input['kwargs'].append(kwargs)
raise ValueError

return Catcher

def get_subsets_in_block(self, block):
return [
{
Expand Down

0 comments on commit f4fb949

Please sign in to comment.