diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py index b5de0b6f..7fc82821 100644 --- a/llmc/compression/quantization/awq.py +++ b/llmc/compression/quantization/awq.py @@ -131,7 +131,7 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) out = out[0] if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]: - org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa + org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device) loss = (org_out - out).float().pow(2).mean().item() @@ -144,13 +144,31 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) loss_mean += x.shape[0] * 1.0 / n_samples * loss scales_mean += x.shape[0] * 1.0 / n_samples * scales inspect_module.load_state_dict(org_sd) - is_best = loss_mean < best_error - if is_best: - best_error = loss_mean - best_scales = scales_mean - best_scales = best_scales.view(-1) - dist.all_reduce(best_scales, op=dist.ReduceOp.SUM) - best_scales /= int(os.environ['WORLD_SIZE']) + is_best = loss_mean < best_error + if is_best: + best_error = loss_mean + best_scales = scales_mean + + # Synchronize across ranks + best_error_tensor = torch.tensor([best_error], device='cuda') + dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN) + global_best_error = best_error_tensor.item() + + # Identify the rank with the minimum loss + global_best_rank = torch.tensor([dist.get_rank() + if best_error == global_best_error + else -1], + device='cuda') + dist.all_reduce(global_best_rank, op=dist.ReduceOp.MAX) + global_best_rank = global_best_rank.item() + + # Broadcast the best scales from the rank with the minimum loss to all ranks + if dist.get_rank() == global_best_rank: + dist.broadcast(best_scales, src=global_best_rank) + else: + best_scales = torch.zeros_like(best_scales, device='cuda') + dist.broadcast(best_scales, src=global_best_rank) + del org_out_dict gc.collect() torch.cuda.empty_cache() @@ -203,13 +221,10 @@ def subset_transform( len(prev_op) in (0, 1) ), 'Only support single prev_op. If multi prev_ops, code need to be updated.' - if len(prev_op) == 0: + if len(prev_op) == 0 or (len(prev_op) == 1 and prev_op[0] is None): logger.info('Cannot apply scale. 
Do not transform this subset.') return - if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()): - return - if isinstance( prev_op[0], tuple( diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 10d97ce0..c4e763f1 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -17,8 +17,8 @@ from .auto_clip import AutoClipper from .hadamard_utils import apply_exact_had_to_linear, get_hadK from .module_utils import (_LLMC_ATTN_MAP_, _LLMC_LINEAR_TYPES_, - _LLMC_LN_TYPES_, _REALQUANT_LINEAR_MAP_, - _TRANSFORMERS_LINEAR_TYPES_, + _LLMC_LN_TYPES_, _LLMC_MOE_GATE_MAP_, + _REALQUANT_LINEAR_MAP_, _TRANSFORMERS_LINEAR_TYPES_, _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) @@ -63,9 +63,6 @@ def a_qdq(self, act, module, aquantizer, input_index=0): else: return aquantizer.fake_quant_act_dynamic(act) - def logit(self, x): - return torch.log(x / (1 - x)) - def get_replacement_params(self, mode='fake_quant', w_only=False, name=None): params_dict = {} if mode == 'fake_quant': @@ -324,6 +321,22 @@ def replace_attention(self, block, extra_modules): extra_modules.update(matmul_modules) extra_modules.update(softmax_modules) + def replace_moe_gate(self, block): + moe_gate_layer = self.model.get_moe_gate(block) + if moe_gate_layer is not None: + logger.info(moe_gate_layer) + moe_gate_module = _LLMC_MOE_GATE_MAP_[self.config['model']['type']] + layers_dict = {'layers': moe_gate_layer} + self.model.replace_module_subset( + moe_gate_module, + block, + layers_dict, + self.block_idx, + self.get_replacement_params( + mode='quant_moegate', w_only=self.w_only, name=None + ), + ) + @torch.no_grad() def collect_block_qparams(self, block): named_linears = self.model.get_block_linears(block) @@ -367,6 +380,7 @@ def block_forward(self, block, input_data=None): return output def block_opt(self, block): + self.replace_moe_gate(block) block = block.cuda() named_linears = self.model.get_block_linears(block) extra_modules = self.model.get_extra_modules(block) @@ -444,7 +458,6 @@ def block_transform(self, block, input_feat, block_kwargs): if self.act_static: self.register_non_linear_qparams(block, input_feat) - self.register_except_subsets_qparams(block, input_feat) self.set_non_linear_mode('fake_quant', block, False) @@ -487,25 +500,6 @@ def block_transform(self, block, input_feat, block_kwargs): def rehook_next_subset(self, block, subset, next_subset): self.subset_init(next_subset) - - layers_except_subsets = self.model.get_linears_except_subsets(block) - if ( - layers_except_subsets - and not isinstance( - layers_except_subsets[list(layers_except_subsets.keys())[0]], - FakeQuantLinear - ) - ): - self.model.replace_module_subset( - FakeQuantLinear, - block, - {'layers': layers_except_subsets}, - self.block_idx, - self.get_replacement_params( - mode='fake_quant', w_only=self.w_only, name=None - ), - ) - self.model.replace_module_subset( FakeQuantLinear, block, @@ -532,14 +526,6 @@ def collect_layers_weights(self, layers, tensor_parallelize_style=None): weights.append(_m.weight) return weights - @torch.no_grad() - def register_except_subsets_qparams(self, block, input_feat): - layers_dict = self.model.get_linears_except_subsets(block) - for name, layer in layers_dict.items(): - input_tensors = copy.deepcopy(input_feat[name]) - self.register_act_qparams({name: layer}, input_tensors) - del 
input_tensors - @torch.no_grad() def register_non_linear_qparams(self, block, input_feat): layer_types = [ diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py index da12d997..68e332d3 100644 --- a/llmc/compression/quantization/module_utils.py +++ b/llmc/compression/quantization/module_utils.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from loguru import logger from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS @@ -844,7 +845,7 @@ def forward(self, x): @torch.no_grad() def new(cls, module, w_qdq, a_qdq): weight = module.weight.data - if module.bias is not None: + if hasattr(module, 'bias') and module.bias is not None: bias = module.bias.data else: bias = None @@ -946,6 +947,126 @@ def __repr__(self): ) +class LlmcDeepSeekV2MoEGate(nn.Module): + def __init__(self, module): + super().__init__() + self.config = module.config + self.top_k = module.config.num_experts_per_tok + self.n_routed_experts = module.config.n_routed_experts + self.routed_scaling_factor = module.config.routed_scaling_factor + self.scoring_func = module.config.scoring_func + self.alpha = module.config.aux_loss_alpha + self.seq_aux = module.config.seq_aux + self.topk_method = module.config.topk_method + self.n_group = module.config.n_group + self.topk_group = module.config.topk_group + + # topk selection algorithm + self.norm_topk_prob = module.config.norm_topk_prob + self.gating_dim = module.config.hidden_size + self.fc = nn.Linear(self.gating_dim, self.n_routed_experts, bias=False) + self.fc.weight = module.weight + + @property + def weight(self): + return self.fc.weight + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + # compute gating score + hidden_states = hidden_states.view(-1, h) + org_dtype = self.fc.weight.dtype + self.fc.weight.data = self.fc.weight.data.to(torch.float32) + logits = self.fc(hidden_states.type(torch.float32)) + self.fc.weight.data = self.fc.weight.data.to(org_dtype) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f'insupportable scoring function for MoE gating: {self.scoring_func}' + ) + + # select top-k experts + if self.topk_method == 'greedy': + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == 'group_limited_greedy': + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * self.routed_scaling_factor + # expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + 
topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + + return topk_idx, topk_weight, aux_loss + + @classmethod + @torch.no_grad() + def new(cls, module): + new_module = cls(module) + new_module.zeros_shape = None + new_module.zeros_dtype = None + return new_module + + def __repr__(self): + return ( + 'LlmcDeepSeekV2MoEGate(' + + f'fc={self.fc})' + ) + + class VllmRealQuantLinear(nn.Module): def __init__(self, weight, bias, scales, input_scale, need_pack): super().__init__() @@ -1308,6 +1429,10 @@ def __repr__(self): 'DeepseekV2': LlmcDeepseekAttention } +_LLMC_MOE_GATE_MAP_ = { + 'DeepseekV2': LlmcDeepSeekV2MoEGate +} + _REALQUANT_LINEAR_MAP_ = { 'vllm_quant': VllmRealQuantLinear, 'sgl_quant': SglRealQuantLinear, diff --git a/llmc/compression/quantization/osplus.py b/llmc/compression/quantization/osplus.py index e4d1da7e..c6003faf 100644 --- a/llmc/compression/quantization/osplus.py +++ b/llmc/compression/quantization/osplus.py @@ -22,12 +22,9 @@ def __init__(self, model, quant_config, input, padding_mask, config, modality='l super().__init__(model, quant_config, input, padding_mask, config, modality) @torch.no_grad() - def filter_subset(self, layers_dict, prev_op): + def filter_subset(self, prev_op): if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)): - if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()): - return False - else: - return True + return True else: return False @@ -196,7 +193,7 @@ def subset_transform( logger.info('Cannot apply scale. 
Do not transform this subset.') return - if not self.filter_subset(layers_dict, prev_op): + if not self.filter_subset(prev_op): logger.info('Do not transform this subset.') return diff --git a/llmc/compression/quantization/quarot.py b/llmc/compression/quantization/quarot.py index dfdf4cfa..72ddc392 100644 --- a/llmc/compression/quantization/quarot.py +++ b/llmc/compression/quantization/quarot.py @@ -51,7 +51,7 @@ def preprocess(self): self.rotate_head(self.Q) # for vlm model - if hasattr(self.model, 'vlm_model'): + if hasattr(self.model, 'vlm_model') and self.model.vlm_model is not None: logger.info('For vlm model, quarot need rotate last layer in projector.') """ txt_input img_input @@ -116,6 +116,19 @@ def subset_transform(self, block, subset): layers = list(layers_dict.values()) + if 'skip_rotate' in subset and subset['skip_rotate']: + return + + if 'need_rotate_alone' in subset and subset['need_rotate_alone']: + assert isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)) + pre_layers = subset['pre_layers'] + post_layers = subset['post_layers'] + for layer in pre_layers + post_layers: + layer = layer.cuda() + self.fuse_ln_fcs(prev_op[0], layers) + self.rotate_pre_layers(pre_layers, self.Q) + self.rotate_post_layers(post_layers, self.Q) + if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)): self.fuse_ln_fcs(prev_op[0], layers) self.rotate_pre_layers(layers, self.Q) @@ -133,10 +146,11 @@ def subset_transform(self, block, subset): logger.info(f'{self.Q.shape}') self.rotate_post_layers(layers, self.Q, exact_had=False) if self.online_rotate: - apply_exact_had_to_linear( - prev_op[0], had_dim=self.head_dim, output=True - ) - apply_exact_had_to_linear(layers[0], had_dim=-1, output=False) + if prev_op[0] is not None: + apply_exact_had_to_linear( + prev_op[0], had_dim=self.head_dim, output=True + ) + apply_exact_had_to_linear(layers[0], had_dim=-1, output=False) @torch.no_grad() def save_model(self, path): diff --git a/llmc/compression/quantization/smoothquant.py b/llmc/compression/quantization/smoothquant.py index de0c4710..3c47b4bc 100644 --- a/llmc/compression/quantization/smoothquant.py +++ b/llmc/compression/quantization/smoothquant.py @@ -18,12 +18,9 @@ def __init__(self, model, quant_config, input, padding_mask, config, modality='l self.alpha = special_config.get('alpha', 0.5) @torch.no_grad() - def filter_subset(self, layers_dict, prev_op): + def filter_subset(self, prev_op): if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)): - if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()): - return False - else: - return True + return True else: return False @@ -71,7 +68,7 @@ def subset_transform( inspect_module, subset_kwargs, ): - if not self.filter_subset(layers_dict, prev_op): + if not self.filter_subset(prev_op): logger.info('Do not transform this subset.') return layers = list(layers_dict.values()) diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py index 7e0fb5c2..741e5d35 100644 --- a/llmc/models/base_model.py +++ b/llmc/models/base_model.py @@ -69,9 +69,6 @@ def get_act_fn_in_block(self): def get_softmax_in_block(self): return {} - def get_linears_except_subsets(self, block): - return {} - @abstractmethod def get_subsets_in_block(self, block): pass @@ -256,6 +253,9 @@ def get_all_linears(self, module): def get_extra_modules(self, block): return {} + def get_moe_gate(self, block): + return None + def set_mix_bits_params_dict(self, block_idx, name, params_dict): logger.info('set_mix_bits_params_dict') 
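The new `get_moe_gate` hook introduced above (returning `None` by default in the base model, overridden per model) is what `replace_moe_gate` in base_blockwise_quantization.py queries before swapping a block's routing gate for the `LlmcDeepSeekV2MoEGate` wrapper registered in `_LLMC_MOE_GATE_MAP_`. The most involved piece of that wrapper is the `group_limited_greedy` branch, which restricts routing to the `topk_group` highest-scoring expert groups before taking the per-token top-k. Below is a self-contained sketch of that selection step, pulled out of the module for illustration only; the helper name `group_limited_topk` and the toy sizes in `__main__` are hypothetical, not part of the patch.

# Sketch of the group_limited_greedy routing used in LlmcDeepSeekV2MoEGate.forward.
import torch


def group_limited_topk(scores, top_k, n_group, topk_group):
    # scores: [tokens, n_experts]; experts are split into n_group equal groups.
    n_tokens, n_experts = scores.shape
    # Score each group by its best expert, keep the topk_group best groups.
    group_scores = scores.view(n_tokens, n_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Expand the group mask back to expert granularity and zero out
    # every expert that lives in a discarded group.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(n_tokens, n_group, n_experts // n_group)
        .reshape(n_tokens, -1)
    )
    masked = scores.masked_fill(~score_mask.bool(), 0.0)
    # Ordinary top-k over the surviving experts.
    return torch.topk(masked, k=top_k, dim=-1, sorted=False)


if __name__ == '__main__':
    torch.manual_seed(0)
    scores = torch.rand(4, 16).softmax(dim=-1)  # 4 tokens, 16 experts
    topk_weight, topk_idx = group_limited_topk(scores, top_k=2, n_group=4, topk_group=2)
    print(topk_idx)  # selected experts all fall inside the 2 best groups per token

Wrapping the gate weight in an `nn.Linear` (`self.fc`) is also what lets `'mlp.gate.fc'` be listed next to the expert projections in the DeepseekV2 subset below, so the gate participates in the same subset transforms as any other linear layer.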
diff --git a/llmc/models/deepseekv2.py b/llmc/models/deepseekv2.py index 9d174a2f..babdb3b2 100644 --- a/llmc/models/deepseekv2.py +++ b/llmc/models/deepseekv2.py @@ -34,11 +34,6 @@ def skip_layer_name(self): def has_bias(self): return False - def get_linears_except_subsets(self, block): - return { - 'self_attn.o_proj': block.self_attn.o_proj - } - def get_layernorms_in_block(self, block): return { 'input_layernorm': block.input_layernorm, @@ -57,13 +52,21 @@ def get_matmul_in_block(self, block): def get_softmax_in_block(self, block): return {'self_attn.softmax': block.self_attn.softmax} - def get_subsets_in_block(self, block): + def get_head_layers(self): + return [self.model.lm_head] - layers = [] + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] - # attn input - if hasattr(block.self_attn, 'q_proj'): + def get_moe_gate(self, block): + if hasattr(block.mlp, 'gate'): + return {'mlp.gate': block.mlp.gate} + else: + return None + def get_subsets_in_block(self, block): + layers = [] + if hasattr(block.self_attn, 'q_proj'): layers.append( { 'layers': { @@ -96,9 +99,21 @@ def get_subsets_in_block(self, block): 'input': ['self_attn.q_b_proj'], 'inspect': block.self_attn.q_b_proj, 'has_kwargs': False, + 'need_rotate_alone': True, + 'pre_layers': [block.self_attn.q_a_proj], + 'post_layers': [block.self_attn.q_b_proj] } ) + layers.append( + { + 'layers': {'self_attn.o_proj': block.self_attn.o_proj}, + 'prev_op': [None], + 'input': ['self_attn.o_proj'], + 'inspect': block.self_attn.o_proj, + 'has_kwargs': False, + }, + ) layers.append( { 'layers': {'self_attn.kv_b_proj': block.self_attn.kv_b_proj}, @@ -106,6 +121,7 @@ def get_subsets_in_block(self, block): 'input': ['self_attn.kv_b_proj'], 'inspect': block.self_attn.kv_b_proj, 'has_kwargs': False, + 'skip_rotate': True } ) @@ -119,6 +135,7 @@ def get_subsets_in_block(self, block): for i in range(len(block.mlp.experts))}, 'mlp.shared_experts.gate_proj': block.mlp.shared_experts.gate_proj, # noqa 'mlp.shared_experts.up_proj': block.mlp.shared_experts.up_proj, + 'mlp.gate.fc': block.mlp.gate.fc, }, 'prev_op': [block.post_attention_layernorm], 'input': ['mlp'],
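For reference, the synchronization added to `search_scale_subset` in awq.py at the top of this patch replaces the old all-reduce averaging of `best_scales` with a select-and-broadcast scheme: each rank reduces its best loss with `MIN`, the owner of that minimum is picked by a `MAX` reduce over rank ids, and the owner's scales are broadcast to everyone else. Below is a minimal standalone sketch of that pattern, assuming `torch.distributed` is initialized (e.g. via `torchrun`), one GPU per rank, and identical scale shapes on every rank; the helper `sync_best_scales` and the toy values in `__main__` are illustrative, not code from the patch.

# Select-and-broadcast sketch: every rank ends up with the scales found by
# the rank that reported the lowest loss.
import os

import torch
import torch.distributed as dist


def sync_best_scales(local_loss: float, local_scales: torch.Tensor) -> torch.Tensor:
    device = local_scales.device

    # 1. Reduce the per-rank best loss to the global minimum.
    loss_t = torch.tensor([local_loss], device=device)
    global_min_t = loss_t.clone()
    dist.all_reduce(global_min_t, op=dist.ReduceOp.MIN)

    # 2. Pick the rank that owns the minimum (ties resolve to the highest rank id
    #    via the MAX reduce). Comparing tensors keeps the check in float32 on both sides.
    is_owner = bool((loss_t == global_min_t).item())
    rank_t = torch.tensor([dist.get_rank() if is_owner else -1], device=device)
    dist.all_reduce(rank_t, op=dist.ReduceOp.MAX)
    src_rank = int(rank_t.item())

    # 3. Broadcast the owner's scales in place to all other ranks.
    dist.broadcast(local_scales, src=src_rank)
    return local_scales


if __name__ == '__main__':
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    rank = dist.get_rank()
    scales = torch.full((8,), float(rank), device='cuda')  # toy per-rank result
    best = sync_best_scales(local_loss=1.0 + rank, local_scales=scales)
    print(rank, best)  # every rank prints rank 0's scales (it has the lowest loss)
    dist.destroy_process_group()

One behavioural consequence of this scheme is worth noting: instead of every rank keeping an average of all ranks' scales, every rank now ends the search holding exactly the scales of the single lowest-loss rank, with ties broken deterministically toward the higher rank id.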