Commit

Merge pull request #204 from ModelTC/dev_fixbug
Support deepseekv2 quarot
gushiqiao authored Nov 19, 2024
2 parents 4f8ae1f + e0c225f commit a6c95fc
Showing 8 changed files with 226 additions and 75 deletions.
39 changes: 27 additions & 12 deletions llmc/compression/quantization/awq.py
@@ -131,7 +131,7 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
out = out[0]

if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)

loss = (org_out - out).float().pow(2).mean().item()
@@ -144,13 +144,31 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
loss_mean += x.shape[0] * 1.0 / n_samples * loss
scales_mean += x.shape[0] * 1.0 / n_samples * scales
inspect_module.load_state_dict(org_sd)
is_best = loss_mean < best_error
if is_best:
best_error = loss_mean
best_scales = scales_mean
best_scales = best_scales.view(-1)
dist.all_reduce(best_scales, op=dist.ReduceOp.SUM)
best_scales /= int(os.environ['WORLD_SIZE'])
is_best = loss_mean < best_error
if is_best:
best_error = loss_mean
best_scales = scales_mean

# Synchronize across ranks
best_error_tensor = torch.tensor([best_error], device='cuda')
dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN)
global_best_error = best_error_tensor.item()

# Identify the rank with the minimum loss
global_best_rank = torch.tensor([dist.get_rank()
if best_error == global_best_error
else -1],
device='cuda')
dist.all_reduce(global_best_rank, op=dist.ReduceOp.MAX)
global_best_rank = global_best_rank.item()

# Broadcast the best scales from the rank with the minimum loss to all ranks
if dist.get_rank() == global_best_rank:
dist.broadcast(best_scales, src=global_best_rank)
else:
best_scales = torch.zeros_like(best_scales, device='cuda')
dist.broadcast(best_scales, src=global_best_rank)

del org_out_dict
gc.collect()
torch.cuda.empty_cache()
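The block added above replaces the old cross-rank averaging of scales (sum followed by division by WORLD_SIZE) with a winner-takes-all scheme: each rank reduces its best loss with ReduceOp.MIN, the rank owning that minimum is found with a ReduceOp.MAX over rank indices, and that rank broadcasts its scales to everyone else. A minimal standalone sketch of the same pattern, with an illustrative function name that is not part of this commit:

import torch
import torch.distributed as dist

def sync_best_scales(local_best_error, local_best_scales):
    # 1) Global minimum of the per-rank best losses.
    best_error = torch.tensor([local_best_error], device='cuda')
    dist.all_reduce(best_error, op=dist.ReduceOp.MIN)

    # 2) Rank that owns the minimum (ties resolved towards the highest rank id).
    owner = torch.tensor(
        [dist.get_rank() if local_best_error == best_error.item() else -1],
        device='cuda',
    )
    dist.all_reduce(owner, op=dist.ReduceOp.MAX)

    # 3) The winning rank broadcasts its scales; all other ranks receive them.
    scales = local_best_scales.detach().clone().cuda()
    dist.broadcast(scales, src=int(owner.item()))
    return scales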
@@ -203,13 +221,10 @@ def subset_transform(
len(prev_op) in (0, 1)
), 'Only support single prev_op. If multi prev_ops, code need to be updated.'

if len(prev_op) == 0:
if len(prev_op) == 0 or (len(prev_op) == 1 and prev_op[0] is None):
logger.info('Cannot apply scale. Do not transform this subset.')
return

if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()):
return

if isinstance(
prev_op[0],
tuple(
52 changes: 19 additions & 33 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -17,8 +17,8 @@
from .auto_clip import AutoClipper
from .hadamard_utils import apply_exact_had_to_linear, get_hadK
from .module_utils import (_LLMC_ATTN_MAP_, _LLMC_LINEAR_TYPES_,
_LLMC_LN_TYPES_, _REALQUANT_LINEAR_MAP_,
_TRANSFORMERS_LINEAR_TYPES_,
_LLMC_LN_TYPES_, _LLMC_MOE_GATE_MAP_,
_REALQUANT_LINEAR_MAP_, _TRANSFORMERS_LINEAR_TYPES_,
_TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear,
FakeQuantLinear, LlmcActFn, OriginFloatLinear,
RotateLinear)
@@ -63,9 +63,6 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
else:
return aquantizer.fake_quant_act_dynamic(act)

def logit(self, x):
return torch.log(x / (1 - x))

def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
params_dict = {}
if mode == 'fake_quant':
@@ -324,6 +321,22 @@ def replace_attention(self, block, extra_modules):
extra_modules.update(matmul_modules)
extra_modules.update(softmax_modules)

def replace_moe_gate(self, block):
moe_gate_layer = self.model.get_moe_gate(block)
if moe_gate_layer is not None:
logger.info(moe_gate_layer)
moe_gate_module = _LLMC_MOE_GATE_MAP_[self.config['model']['type']]
layers_dict = {'layers': moe_gate_layer}
self.model.replace_module_subset(
moe_gate_module,
block,
layers_dict,
self.block_idx,
self.get_replacement_params(
mode='quant_moegate', w_only=self.w_only, name=None
),
)

@torch.no_grad()
def collect_block_qparams(self, block):
named_linears = self.model.get_block_linears(block)
@@ -367,6 +380,7 @@ def block_forward(self, block, input_data=None):
return output

def block_opt(self, block):
self.replace_moe_gate(block)
block = block.cuda()
named_linears = self.model.get_block_linears(block)
extra_modules = self.model.get_extra_modules(block)
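block_opt now calls replace_moe_gate before any other block processing, so MoE gates get wrapped ahead of calibration. The hook is a no-op whenever self.model.get_moe_gate(block) returns None; a hedged sketch of what that model-side helper is assumed to provide for DeepseekV2 (the attribute path and return format are assumptions, not code from this commit):

def get_moe_gate(self, block):
    # Hypothetical helper on the DeepseekV2 model wrapper: MoE blocks expose their
    # gate module, dense blocks return None so replace_moe_gate() skips them.
    if hasattr(block.mlp, 'gate'):
        return {'mlp.gate': block.mlp.gate}
    return None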
@@ -444,7 +458,6 @@ def block_transform(self, block, input_feat, block_kwargs):

if self.act_static:
self.register_non_linear_qparams(block, input_feat)
self.register_except_subsets_qparams(block, input_feat)

self.set_non_linear_mode('fake_quant', block, False)

@@ -487,25 +500,6 @@ def block_transform(self, block, input_feat, block_kwargs):

def rehook_next_subset(self, block, subset, next_subset):
self.subset_init(next_subset)

layers_except_subsets = self.model.get_linears_except_subsets(block)
if (
layers_except_subsets
and not isinstance(
layers_except_subsets[list(layers_except_subsets.keys())[0]],
FakeQuantLinear
)
):
self.model.replace_module_subset(
FakeQuantLinear,
block,
{'layers': layers_except_subsets},
self.block_idx,
self.get_replacement_params(
mode='fake_quant', w_only=self.w_only, name=None
),
)

self.model.replace_module_subset(
FakeQuantLinear,
block,
@@ -532,14 +526,6 @@ def collect_layers_weights(self, layers, tensor_parallelize_style=None):
weights.append(_m.weight)
return weights

@torch.no_grad()
def register_except_subsets_qparams(self, block, input_feat):
layers_dict = self.model.get_linears_except_subsets(block)
for name, layer in layers_dict.items():
input_tensors = copy.deepcopy(input_feat[name])
self.register_act_qparams({name: layer}, input_tensors)
del input_tensors

@torch.no_grad()
def register_non_linear_qparams(self, block, input_feat):
layer_types = [
127 changes: 126 additions & 1 deletion llmc/compression/quantization/module_utils.py
@@ -4,6 +4,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

@@ -844,7 +845,7 @@ def forward(self, x):
@torch.no_grad()
def new(cls, module, w_qdq, a_qdq):
weight = module.weight.data
if module.bias is not None:
if hasattr(module, 'bias') and module.bias is not None:
bias = module.bias.data
else:
bias = None
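The widened guard above covers modules that carry a weight but define no bias attribute at all; for such modules, module.bias raises AttributeError rather than returning None, so the old check could not even be evaluated. A toy illustration under that assumption (the class name is made up):

import torch
import torch.nn as nn

class WeightOnlyGate(nn.Module):
    # Gate-style module: a weight parameter, but no bias attribute of any kind.
    def __init__(self, out_features, in_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

m = WeightOnlyGate(8, 16)
# nn.Module.__getattr__ raises AttributeError for m.bias, so hasattr() is what
# keeps the check safe; the result here is simply None.
bias = m.bias.data if hasattr(m, 'bias') and m.bias is not None else None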
@@ -946,6 +947,126 @@ def __repr__(self):
)


class LlmcDeepSeekV2MoEGate(nn.Module):
def __init__(self, module):
super().__init__()
self.config = module.config
self.top_k = module.config.num_experts_per_tok
self.n_routed_experts = module.config.n_routed_experts
self.routed_scaling_factor = module.config.routed_scaling_factor
self.scoring_func = module.config.scoring_func
self.alpha = module.config.aux_loss_alpha
self.seq_aux = module.config.seq_aux
self.topk_method = module.config.topk_method
self.n_group = module.config.n_group
self.topk_group = module.config.topk_group

# topk selection algorithm
self.norm_topk_prob = module.config.norm_topk_prob
self.gating_dim = module.config.hidden_size
self.fc = nn.Linear(self.gating_dim, self.n_routed_experts, bias=False)
self.fc.weight = module.weight

@property
def weight(self):
return self.fc.weight

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# compute gating score
hidden_states = hidden_states.view(-1, h)
org_dtype = self.fc.weight.dtype
self.fc.weight.data = self.fc.weight.data.to(torch.float32)
logits = self.fc(hidden_states.type(torch.float32))
self.fc.weight.data = self.fc.weight.data.to(org_dtype)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1, dtype=torch.float32)
else:
raise NotImplementedError(
f'insupportable scoring function for MoE gating: {self.scoring_func}'
)

# select top-k experts
if self.topk_method == 'greedy':
topk_weight, topk_idx = torch.topk(
scores, k=self.top_k, dim=-1, sorted=False
)
elif self.topk_method == 'group_limited_greedy':
group_scores = (
scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(
group_scores, k=self.topk_group, dim=-1, sorted=False
)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(
bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
)
.reshape(bsz * seq_len, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weight, topk_idx = torch.topk(
tmp_scores, k=self.top_k, dim=-1, sorted=False
)

# norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
else:
topk_weight = topk_weight * self.routed_scaling_factor
# expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(
bsz, self.n_routed_experts, device=hidden_states.device
)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
dim=1
).mean() * self.alpha
else:
mask_ce = F.one_hot(
topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None

return topk_idx, topk_weight, aux_loss

@classmethod
@torch.no_grad()
def new(cls, module):
new_module = cls(module)
new_module.zeros_shape = None
new_module.zeros_dtype = None
return new_module

def __repr__(self):
return (
'LlmcDeepSeekV2MoEGate('
+ f'fc={self.fc})'
)


class VllmRealQuantLinear(nn.Module):
def __init__(self, weight, bias, scales, input_scale, need_pack):
super().__init__()
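The group_limited_greedy branch of LlmcDeepSeekV2MoEGate.forward above is easiest to follow with concrete shapes: experts are split into groups, only the topk_group best-scoring groups survive, and the final top-k is taken over experts in the surviving groups. A small self-contained walk-through with toy sizes (not tied to any real model config):

import torch

n, e, n_group, topk_group, top_k = 3, 8, 4, 2, 2   # 3 tokens, 8 experts in 4 groups of 2
scores = torch.rand(n, e)

group_scores = scores.view(n, n_group, -1).max(dim=-1).values          # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)  # keep 2 of 4 groups
score_mask = (group_mask.unsqueeze(-1)
              .expand(n, n_group, e // n_group)
              .reshape(n, -1))                                          # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)                # zero out pruned groups
topk_weight, topk_idx = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)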
@@ -1308,6 +1429,10 @@ def __repr__(self):
'DeepseekV2': LlmcDeepseekAttention
}

_LLMC_MOE_GATE_MAP_ = {
'DeepseekV2': LlmcDeepSeekV2MoEGate
}

_REALQUANT_LINEAR_MAP_ = {
'vllm_quant': VllmRealQuantLinear,
'sgl_quant': SglRealQuantLinear,
9 changes: 3 additions & 6 deletions llmc/compression/quantization/osplus.py
@@ -22,12 +22,9 @@ def __init__(self, model, quant_config, input, padding_mask, config, modality='l
super().__init__(model, quant_config, input, padding_mask, config, modality)

@torch.no_grad()
def filter_subset(self, layers_dict, prev_op):
def filter_subset(self, prev_op):
if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()):
return False
else:
return True
return True
else:
return False

@@ -196,7 +193,7 @@ def subset_transform(
logger.info('Cannot apply scale. Do not transform this subset.')
return

if not self.filter_subset(layers_dict, prev_op):
if not self.filter_subset(prev_op):
logger.info('Do not transform this subset.')
return

24 changes: 19 additions & 5 deletions llmc/compression/quantization/quarot.py
@@ -51,7 +51,7 @@ def preprocess(self):
self.rotate_head(self.Q)

# for vlm model
if hasattr(self.model, 'vlm_model'):
if hasattr(self.model, 'vlm_model') and self.model.vlm_model is not None:
logger.info('For vlm model, quarot need rotate last layer in projector.')
"""
txt_input img_input
@@ -116,6 +116,19 @@ def subset_transform(self, block, subset):

layers = list(layers_dict.values())

if 'skip_rotate' in subset and subset['skip_rotate']:
return

if 'need_rotate_alone' in subset and subset['need_rotate_alone']:
assert isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_))
pre_layers = subset['pre_layers']
post_layers = subset['post_layers']
for layer in pre_layers + post_layers:
layer = layer.cuda()
self.fuse_ln_fcs(prev_op[0], layers)
self.rotate_pre_layers(pre_layers, self.Q)
self.rotate_post_layers(post_layers, self.Q)

if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
self.fuse_ln_fcs(prev_op[0], layers)
self.rotate_pre_layers(layers, self.Q)
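The two new branches are driven by optional keys on the subset definition: skip_rotate bypasses the subset entirely, while need_rotate_alone fuses the preceding norm into the subset's layers and then rotates pre_layers and post_layers with Q on the input and output sides. A hedged illustration of the shape such a subset entry is assumed to take (module names and sizes are placeholders, not the actual DeepseekV2 subsets returned by the model class):

import torch.nn as nn

# Placeholder modules purely for illustration; real subsets come from the model class.
rms_norm = nn.LayerNorm(64)
q_proj, pre_linear, post_linear = (nn.Linear(64, 64) for _ in range(3))

subset = {
    'prev_op': [rms_norm],                 # must be an LN/RMSNorm for the rotate-alone path
    'layers': {'self_attn.q_proj': q_proj},
    'skip_rotate': False,                  # True -> leave this subset untouched
    'need_rotate_alone': True,             # also rotate pre_layers / post_layers with Q
    'pre_layers': [pre_linear],            # rotated on the input side
    'post_layers': [post_linear],          # rotated on the output side
}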
@@ -133,10 +146,11 @@
logger.info(f'{self.Q.shape}')
self.rotate_post_layers(layers, self.Q, exact_had=False)
if self.online_rotate:
apply_exact_had_to_linear(
prev_op[0], had_dim=self.head_dim, output=True
)
apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)
if prev_op[0] is not None:
apply_exact_had_to_linear(
prev_op[0], had_dim=self.head_dim, output=True
)
apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)

@torch.no_grad()
def save_model(self, path):