From 090939330390991d97a9d4554ee043ecd69288cd Mon Sep 17 00:00:00 2001
From: helloyongyang
Date: Thu, 21 Nov 2024 21:42:29 +0800
Subject: [PATCH] add do_trans in config & remove language catcher & support chatglm

---
 llmc/compression/quantization/awq.py |  5 ++
 .../base_blockwise_quantization.py   | 13 ++-
 llmc/models/__init__.py              |  1 +
 llmc/models/base_model.py            | 26 +-----
 llmc/models/chatglm.py               | 88 +++++++++++++++++++
 llmc/models/opt.py                   |  1 +
 llmc/models/vit.py                   | 22 -----
 7 files changed, 109 insertions(+), 47 deletions(-)
 create mode 100644 llmc/models/chatglm.py

diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py
index 3f0c0d1c..4633afcd 100644
--- a/llmc/compression/quantization/awq.py
+++ b/llmc/compression/quantization/awq.py
@@ -203,6 +203,10 @@ def subset_transform(
         prev_op = subset['prev_op']
         input_name = subset['input'][0]
         inspect_module = subset['inspect']
+        do_trans = subset.get('do_trans', True)
+        if not do_trans:
+            logger.info('do_trans is set to False. Do not transform this subset.')
+            return

         if not check_do_quant(
             self.block_idx,
@@ -241,6 +245,7 @@ def subset_transform(
             if (
                 isinstance(prev_op[0], (nn.Linear, FakeQuantLinear))
                 and prev_op[0].out_features != layers[0].in_features * 3
+                and prev_op[0].out_features != layers[0].in_features * 2
                 and prev_op[0].out_features != layers[0].in_features
             ):
                 logger.info('Cannot apply scale. Do not transform this subset.')
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 5178de8c..c8dfe94b 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -598,6 +598,7 @@ def apply_shift(self, shifts, prev_op, layers):
     def scale_fc_fc(self, fc1, fc2, scales):
         scales = scales.to(fc1.weight.device)
         if fc1.out_features == fc2.in_features * 3:
+            logger.info('fc1.out_features == fc2.in_features * 3')
             num_heads = self.model.get_num_attention_heads()
             fc1.weight.t_()
             org_shape = fc1.weight.shape
@@ -616,13 +617,23 @@ def scale_fc_fc(self, fc1, fc2, scales):
                     fc1.bias[:, 2, :].shape
                 )
             fc1.bias.data = fc1.bias.data.reshape(-1)
-        else:
+        elif fc1.out_features == fc2.in_features * 2:
+            logger.info('fc1.out_features == fc2.in_features * 2')
+            fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
+            if hasattr(fc1, 'bias') and fc1.bias is not None:
+                fc1.bias.data[fc1.bias.data.shape[0] // 2:].div_(scales.view(-1))
+        elif fc1.out_features == fc2.in_features:
+            logger.info('fc1.out_features == fc2.in_features')
             assert fc1.out_features == fc2.in_features

             if hasattr(fc1, 'bias') and fc1.bias is not None:
                 fc1.bias.div_(scales.view(-1))

             fc1.weight.div_(scales.view(-1, 1))
+        else:
+            logger.error(f'fc1.out_features: {fc1.out_features}')
+            logger.error(f'fc2.in_features: {fc2.in_features}')
+            raise Exception('Can not scale this fc-fc.')

         fc2.weight.mul_(scales.view(1, -1))

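Why the new fc1.out_features == fc2.in_features * 2 branch in scale_fc_fc divides only the second half of fc1: in a ChatGLM-style SwiGLU MLP, dense_h_to_4h packs [gate, up] along its output dimension and only the up half reaches dense_4h_to_h linearly, so per-channel scales can be folded into that half and into dense_4h_to_h without changing the output. The standalone sketch below checks that equivalence (shapes and module names are illustrative assumptions, not llmc code); rescaling the gate half instead would push the scales through the SiLU nonlinearity and break the equivalence.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-ins for ChatGLM's dense_h_to_4h / dense_4h_to_h pair.
hidden, ffn = 8, 16
fc1 = nn.Linear(hidden, 2 * ffn, bias=False)   # packs [gate, up] along dim 0
fc2 = nn.Linear(ffn, hidden, bias=False)
scales = torch.rand(ffn) + 0.5                 # per-channel scales

def mlp(x):
    gate, up = fc1(x).chunk(2, dim=-1)
    return fc2(F.silu(gate) * up)

x = torch.randn(4, hidden)
ref = mlp(x)

with torch.no_grad():
    # Fold the scales: divide only the 'up' half of fc1's output channels,
    # multiply fc2's input channels by the same scales.
    fc1.weight[fc1.weight.shape[0] // 2:].div_(scales.view(-1, 1))
    fc2.weight.mul_(scales.view(1, -1))

print(torch.allclose(ref, mlp(x), atol=1e-5))  # expected: True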
diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py
index 276a443a..e3da94a0 100644
--- a/llmc/models/__init__.py
+++ b/llmc/models/__init__.py
@@ -1,4 +1,5 @@
 from .bloom import Bloom
+from .chatglm import ChatGLM
 from .deepseekv2 import DeepseekV2
 from .falcon import Falcon
 from .gemma2 import Gemma2
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index 9fd6cfb9..5c7ed081 100644
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -106,8 +106,7 @@ def get_attention_rotary_layers(self):
     def batch_process(self):
         raise Exception('batch_process should not be called here.')

-    def get_vision_catcher(self, first_block_input):
-
+    def get_catcher(self, first_block_input):
         class Catcher(nn.Module):
             def __init__(self, module):
                 super().__init__()
@@ -125,24 +124,6 @@ def forward(self, *args, **kwargs):
                     kwargs.pop('output_router_logits')
                 first_block_input['kwargs'].append(kwargs)
                 raise ValueError
-
-        return Catcher
-
-    def get_language_catcher(self, first_block_input):
-
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-
-            def forward(self, inp, **kwargs):
-                first_block_input['data'].append(inp)
-                if 'output_router_logits' in kwargs:
-                    assert kwargs['output_router_logits'] is False
-                    kwargs.pop('output_router_logits')
-                first_block_input['kwargs'].append(kwargs)
-                raise ValueError
-
         return Catcher

     def __str__(self):
@@ -184,10 +165,7 @@ def collect_first_block_input(self, calib_data, padding_mask=None,
         first_block_input = defaultdict(list)

         self.find_blocks(modality)
-        if modality == 'language':
-            Catcher = self.get_language_catcher(first_block_input)
-        elif modality == 'vision':
-            Catcher = self.get_vision_catcher(first_block_input)
+        Catcher = self.get_catcher(first_block_input)
         self.move_embed_to_device('cuda')

         if data_type == 'img_txt':
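The unified get_catcher keeps the early-exit trick both removed catchers relied on: the first block is wrapped, the calibration forward records whatever reaches it, and a raised ValueError (swallowed by the caller) stops the rest of the pass. A simplified, self-contained illustration of that pattern follows; the toy blocks and the call site are assumptions for the sketch, not llmc's real wiring.

from collections import defaultdict

import torch
import torch.nn as nn

first_block_input = defaultdict(list)

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Record what reaches the first block, then abort the forward pass.
        first_block_input['data'].append(args[0])
        first_block_input['kwargs'].append(kwargs)
        raise ValueError

blocks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])  # stand-in blocks
blocks[0] = Catcher(blocks[0])

try:
    x = torch.randn(2, 4)
    for block in blocks:
        x = block(x)
except ValueError:
    pass  # expected: block 0's input captured, deeper blocks never run

print(len(first_block_input['data']))  # 1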
diff --git a/llmc/models/chatglm.py b/llmc/models/chatglm.py
new file mode 100644
index 00000000..f3606226
--- /dev/null
+++ b/llmc/models/chatglm.py
@@ -0,0 +1,88 @@
+import inspect
+
+import torch.nn as nn
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY
+class ChatGLM(BaseModel):
+    def __init__(self, config, device_map=None, use_cache=False):
+        super().__init__(config, device_map, use_cache)
+
+    def find_blocks(self, modality='language'):
+        self.blocks = self.model.transformer.encoder.layers
+
+    def find_embed_layers(self):
+        self.embedding = self.model.transformer.embedding
+        self.rotary_pos_emb = self.model.transformer.rotary_pos_emb
+
+    def find_block_name(self):
+        self.block_name_prefix = 'transformer.encoder.layers'
+
+    def get_embed_layers(self):
+        return [self.embedding]
+
+    def get_attention_rotary_layers(self):
+        return [self.rotary_pos_emb]
+
+    def get_head_layers(self):
+        return [self.model.transformer.output_layer]
+
+    def get_pre_head_layernorm_layers(self):
+        return [self.model.transformer.encoder.final_layernorm]
+
+    def get_layers_except_blocks(self):
+        return [self.embedding, self.rotary_pos_emb, self.model.transformer.output_layer, self.model.transformer.encoder.final_layernorm]  # noqa
+
+    def skip_layer_name(self):
+        return ['final_layernorm']
+
+    def has_bias(self):
+        return False
+
+    def get_layernorms_in_block(self, block):
+        return {
+            'input_layernorm': block.input_layernorm,
+            'post_attention_layernorm': block.post_attention_layernorm,
+        }
+
+    def get_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'self_attention.query_key_value': block.self_attention.query_key_value
+                },
+                'prev_op': [block.input_layernorm],
+                'input': ['self_attention.query_key_value'],
+                'inspect': block.self_attention,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'self_attention.dense': block.self_attention.dense},
+                'prev_op': [block.self_attention.query_key_value],
+                'input': ['self_attention.dense'],
+                'inspect': block.self_attention.dense,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {
+                    'mlp.dense_h_to_4h': block.mlp.dense_h_to_4h
+                },
+                'prev_op': [block.post_attention_layernorm],
+                'input': ['mlp.dense_h_to_4h'],
+                'inspect': block.mlp,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.down_proj': block.mlp.dense_4h_to_h},
+                'prev_op': [block.mlp.dense_h_to_4h],
+                'input': ['mlp.dense_4h_to_h'],
+                'inspect': block.mlp.dense_4h_to_h,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]
diff --git a/llmc/models/opt.py b/llmc/models/opt.py
index 0c5dfe94..a0c0e5cc 100644
--- a/llmc/models/opt.py
+++ b/llmc/models/opt.py
@@ -85,5 +85,6 @@ def get_subsets_in_block(self, block):
                 'inspect': block.fc2,
                 'has_kwargs': False,
                 'is_mlp': True,
+                'do_trans': False
             },
         ]
diff --git a/llmc/models/vit.py b/llmc/models/vit.py
index 4de10e11..8f7e441b 100644
--- a/llmc/models/vit.py
+++ b/llmc/models/vit.py
@@ -78,28 +78,6 @@ def batch_process(self, imgs):
             samples.append(sample)
         return samples

-    def get_catcher(self, first_block_input):
-
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-                self.signature = inspect.signature(module.forward)
-
-            def forward(self, *args, **kwargs):
-                params = list(self.signature.parameters.keys())
-                for i, arg in enumerate(args):
-                    if i > 0:
-                        kwargs[params[i]] = arg
-                first_block_input['data'].append(args[0])
-                if 'output_router_logits' in kwargs:
-                    assert kwargs['output_router_logits'] is False
-                    kwargs.pop('output_router_logits')
-                first_block_input['kwargs'].append(kwargs)
-                raise ValueError
-
-        return Catcher
-
     def get_subsets_in_block(self, block):
         return [
             {
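The do_trans flag added in awq.py is read per subset with a default of True, so a model's get_subsets_in_block can opt individual subsets out of the AWQ transform the way opt.py now does for its fc2 subset. A minimal sketch of how such a flag is consumed follows; the subset dicts and the logging setup are illustrative, not copied from llmc.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def subset_transform(subset):
    do_trans = subset.get('do_trans', True)  # a missing key means "transform"
    if not do_trans:
        logger.info('do_trans is set to False. Do not transform this subset.')
        return
    logger.info('Transforming subset: %s', list(subset['layers']))
    # ... scale search and weight folding would happen here ...

# A get_subsets_in_block() entry can opt out, as opt.py does for fc2:
subset_transform({'layers': {'fc2': None}, 'do_trans': False})
subset_transform({'layers': {'fc1': None}})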