Support deepseekv2 quarot #204

Merged
merged 1 commit on Nov 19, 2024
39 changes: 27 additions & 12 deletions llmc/compression/quantization/awq.py
@@ -131,7 +131,7 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
out = out[0]

if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)

loss = (org_out - out).float().pow(2).mean().item()
@@ -144,13 +144,31 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
loss_mean += x.shape[0] * 1.0 / n_samples * loss
scales_mean += x.shape[0] * 1.0 / n_samples * scales
inspect_module.load_state_dict(org_sd)
is_best = loss_mean < best_error
if is_best:
best_error = loss_mean
best_scales = scales_mean
best_scales = best_scales.view(-1)
dist.all_reduce(best_scales, op=dist.ReduceOp.SUM)
best_scales /= int(os.environ['WORLD_SIZE'])
is_best = loss_mean < best_error
if is_best:
best_error = loss_mean
best_scales = scales_mean

# Synchronize across ranks
best_error_tensor = torch.tensor([best_error], device='cuda')
dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN)
global_best_error = best_error_tensor.item()

# Identify the rank with the minimum loss
global_best_rank = torch.tensor([dist.get_rank()
if best_error == global_best_error
else -1],
device='cuda')
dist.all_reduce(global_best_rank, op=dist.ReduceOp.MAX)
global_best_rank = global_best_rank.item()

# Broadcast the best scales from the rank with the minimum loss to all ranks
if dist.get_rank() == global_best_rank:
dist.broadcast(best_scales, src=global_best_rank)
else:
best_scales = torch.zeros_like(best_scales, device='cuda')
dist.broadcast(best_scales, src=global_best_rank)

del org_out_dict
gc.collect()
torch.cuda.empty_cache()
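Note on the distributed change in this hunk: the previous code all-reduced `best_scales` with SUM and divided by `WORLD_SIZE`, i.e. it averaged the per-rank winners, while the new code finds the globally smallest `loss_mean`, identifies which rank produced it, and broadcasts that rank's scales to everyone. A minimal sketch of the same pattern, assuming `torch.distributed` is already initialized and the tensors live on the current CUDA device:

```python
import torch
import torch.distributed as dist

def sync_best_scales(local_loss: float, local_scales: torch.Tensor) -> torch.Tensor:
    """Adopt the scales from whichever rank observed the lowest loss (sketch)."""
    loss_t = torch.tensor([local_loss], device='cuda')
    dist.all_reduce(loss_t, op=dist.ReduceOp.MIN)          # global minimum loss
    # The rank owning the minimum advertises its id, everyone else sends -1;
    # MAX resolves exact ties in favour of the highest rank.
    owner = torch.tensor(
        [dist.get_rank() if local_loss == loss_t.item() else -1], device='cuda'
    )
    dist.all_reduce(owner, op=dist.ReduceOp.MAX)
    dist.broadcast(local_scales, src=int(owner.item()))    # winner's scales overwrite local ones
    return local_scales
```

Compared with averaging, this keeps the winning search result intact instead of blending scales that were tuned against different calibration shards.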
@@ -203,13 +221,10 @@ def subset_transform(
len(prev_op) in (0, 1)
), 'Only support single prev_op. If multi prev_ops, code need to be updated.'

if len(prev_op) == 0:
if len(prev_op) == 0 or (len(prev_op) == 1 and prev_op[0] is None):
logger.info('Cannot apply scale. Do not transform this subset.')
return

if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()):
return

if isinstance(
prev_op[0],
tuple(
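On the `subset_transform` hunk above: the hard-coded `'mlp.experts.0.gate_proj'` check is removed, and instead any subset whose single `prev_op` entry is `None` is skipped, so a model definition can opt a subset out of scaling by simply not naming a producer layer. An illustrative subset a DeepSeek-V2-style model might emit for its routed experts; the field names follow llmc's subset convention as I understand it, so treat them as assumptions:

```python
# Experts have no single shared producer to fold a scale into, so prev_op is [None]
# and the new guard makes AWQ log 'Cannot apply scale' and return.
experts_subset = {
    'layers': {f'mlp.experts.{i}.gate_proj': e.gate_proj
               for i, e in enumerate(block.mlp.experts)},
    'prev_op': [None],
    'input': ['mlp.experts.0.gate_proj'],
    'inspect': block.mlp,
    'has_kwargs': False,
}
```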
52 changes: 19 additions & 33 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -17,8 +17,8 @@
from .auto_clip import AutoClipper
from .hadamard_utils import apply_exact_had_to_linear, get_hadK
from .module_utils import (_LLMC_ATTN_MAP_, _LLMC_LINEAR_TYPES_,
_LLMC_LN_TYPES_, _REALQUANT_LINEAR_MAP_,
_TRANSFORMERS_LINEAR_TYPES_,
_LLMC_LN_TYPES_, _LLMC_MOE_GATE_MAP_,
_REALQUANT_LINEAR_MAP_, _TRANSFORMERS_LINEAR_TYPES_,
_TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear,
FakeQuantLinear, LlmcActFn, OriginFloatLinear,
RotateLinear)
@@ -63,9 +63,6 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
else:
return aquantizer.fake_quant_act_dynamic(act)

def logit(self, x):
return torch.log(x / (1 - x))

def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
params_dict = {}
if mode == 'fake_quant':
@@ -324,6 +321,22 @@ def replace_attention(self, block, extra_modules):
extra_modules.update(matmul_modules)
extra_modules.update(softmax_modules)

def replace_moe_gate(self, block):
moe_gate_layer = self.model.get_moe_gate(block)
if moe_gate_layer is not None:
logger.info(moe_gate_layer)
moe_gate_module = _LLMC_MOE_GATE_MAP_[self.config['model']['type']]
layers_dict = {'layers': moe_gate_layer}
self.model.replace_module_subset(
moe_gate_module,
block,
layers_dict,
self.block_idx,
self.get_replacement_params(
mode='quant_moegate', w_only=self.w_only, name=None
),
)

@torch.no_grad()
def collect_block_qparams(self, block):
named_linears = self.model.get_block_linears(block)
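Context for the new `replace_moe_gate` hook: the stock DeepSeek-V2 gate computes its routing logits with a functional matmul against a bare `nn.Parameter`, so module-level replacement passes and activation observers never see a Linear there. The wrapper looked up in `_LLMC_MOE_GATE_MAP_` (added in module_utils.py below) re-houses that weight in a real `nn.Linear`. A small self-contained illustration of what that buys, using stand-in gate classes rather than llmc code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BareGate(nn.Module):                     # stock-style: weight is a bare Parameter
    def __init__(self, hidden, experts):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(experts, hidden))

    def forward(self, x):
        return F.linear(x.float(), self.weight)

class WrappedGate(nn.Module):                  # llmc-style: the same matmul via nn.Linear
    def __init__(self, bare):
        super().__init__()
        experts, hidden = bare.weight.shape
        self.fc = nn.Linear(hidden, experts, bias=False)
        self.fc.weight = bare.weight           # share the original routing weight

    def forward(self, x):
        return self.fc(x.float())

bare, wrapped = BareGate(16, 4), WrappedGate(BareGate(16, 4))
print(any(isinstance(m, nn.Linear) for m in bare.modules()))     # False -> replacement passes skip it
print(any(isinstance(m, nn.Linear) for m in wrapped.modules()))  # True  -> can be swapped for a quantized linear
```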
@@ -367,6 +380,7 @@ def block_forward(self, block, input_data=None):
return output

def block_opt(self, block):
self.replace_moe_gate(block)
block = block.cuda()
named_linears = self.model.get_block_linears(block)
extra_modules = self.model.get_extra_modules(block)
@@ -444,7 +458,6 @@ def block_transform(self, block, input_feat, block_kwargs):

if self.act_static:
self.register_non_linear_qparams(block, input_feat)
self.register_except_subsets_qparams(block, input_feat)

self.set_non_linear_mode('fake_quant', block, False)

@@ -487,25 +500,6 @@ def block_transform(self, block, input_feat, block_kwargs):

def rehook_next_subset(self, block, subset, next_subset):
self.subset_init(next_subset)

layers_except_subsets = self.model.get_linears_except_subsets(block)
if (
layers_except_subsets
and not isinstance(
layers_except_subsets[list(layers_except_subsets.keys())[0]],
FakeQuantLinear
)
):
self.model.replace_module_subset(
FakeQuantLinear,
block,
{'layers': layers_except_subsets},
self.block_idx,
self.get_replacement_params(
mode='fake_quant', w_only=self.w_only, name=None
),
)

self.model.replace_module_subset(
FakeQuantLinear,
block,
@@ -532,14 +526,6 @@ def collect_layers_weights(self, layers, tensor_parallelize_style=None):
weights.append(_m.weight)
return weights

@torch.no_grad()
def register_except_subsets_qparams(self, block, input_feat):
layers_dict = self.model.get_linears_except_subsets(block)
for name, layer in layers_dict.items():
input_tensors = copy.deepcopy(input_feat[name])
self.register_act_qparams({name: layer}, input_tensors)
del input_tensors

@torch.no_grad()
def register_non_linear_qparams(self, block, input_feat):
layer_types = [
127 changes: 126 additions & 1 deletion llmc/compression/quantization/module_utils.py
@@ -4,6 +4,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

@@ -844,7 +845,7 @@ def forward(self, x):
@torch.no_grad()
def new(cls, module, w_qdq, a_qdq):
weight = module.weight.data
if module.bias is not None:
if hasattr(module, 'bias') and module.bias is not None:
bias = module.bias.data
else:
bias = None
@@ -946,6 +947,126 @@ def __repr__(self):
)


class LlmcDeepSeekV2MoEGate(nn.Module):
def __init__(self, module):
super().__init__()
self.config = module.config
self.top_k = module.config.num_experts_per_tok
self.n_routed_experts = module.config.n_routed_experts
self.routed_scaling_factor = module.config.routed_scaling_factor
self.scoring_func = module.config.scoring_func
self.alpha = module.config.aux_loss_alpha
self.seq_aux = module.config.seq_aux
self.topk_method = module.config.topk_method
self.n_group = module.config.n_group
self.topk_group = module.config.topk_group

# topk selection algorithm
self.norm_topk_prob = module.config.norm_topk_prob
self.gating_dim = module.config.hidden_size
self.fc = nn.Linear(self.gating_dim, self.n_routed_experts, bias=False)
self.fc.weight = module.weight

@property
def weight(self):
return self.fc.weight

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# compute gating score
hidden_states = hidden_states.view(-1, h)
org_dtype = self.fc.weight.dtype
self.fc.weight.data = self.fc.weight.data.to(torch.float32)
logits = self.fc(hidden_states.type(torch.float32))
self.fc.weight.data = self.fc.weight.data.to(org_dtype)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1, dtype=torch.float32)
else:
raise NotImplementedError(
f'insupportable scoring function for MoE gating: {self.scoring_func}'
)

# select top-k experts
if self.topk_method == 'greedy':
topk_weight, topk_idx = torch.topk(
scores, k=self.top_k, dim=-1, sorted=False
)
elif self.topk_method == 'group_limited_greedy':
group_scores = (
scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(
group_scores, k=self.topk_group, dim=-1, sorted=False
)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(
bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
)
.reshape(bsz * seq_len, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weight, topk_idx = torch.topk(
tmp_scores, k=self.top_k, dim=-1, sorted=False
)

# norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
else:
topk_weight = topk_weight * self.routed_scaling_factor
# expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(
bsz, self.n_routed_experts, device=hidden_states.device
)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
dim=1
).mean() * self.alpha
else:
mask_ce = F.one_hot(
topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None

return topk_idx, topk_weight, aux_loss

@classmethod
@torch.no_grad()
def new(cls, module):
new_module = cls(module)
new_module.zeros_shape = None
new_module.zeros_dtype = None
return new_module

def __repr__(self):
return (
'LlmcDeepSeekV2MoEGate('
+ f'fc={self.fc})'
)


class VllmRealQuantLinear(nn.Module):
def __init__(self, weight, bias, scales, input_scale, need_pack):
super().__init__()
@@ -1308,6 +1429,10 @@ def __repr__(self):
'DeepseekV2': LlmcDeepseekAttention
}

_LLMC_MOE_GATE_MAP_ = {
'DeepseekV2': LlmcDeepSeekV2MoEGate
}

_REALQUANT_LINEAR_MAP_ = {
'vllm_quant': VllmRealQuantLinear,
'sgl_quant': SglRealQuantLinear,
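The `group_limited_greedy` branch in `LlmcDeepSeekV2MoEGate.forward` above is the part of the routing that is easiest to misread: experts are partitioned into `n_group` groups, only the `topk_group` best groups (ranked by their best expert) stay eligible, and the final top-k is taken over the masked scores. A toy trace with made-up numbers (1 token, 8 experts in 4 groups of 2, `topk_group=1`, `top_k=2`); this is an illustration of the selection logic, not llmc code:

```python
import torch

scores = torch.tensor([[0.05, 0.30, 0.10, 0.02, 0.25, 0.01, 0.07, 0.20]])   # [n=1, e=8]
group_scores = scores.view(1, 4, -1).max(dim=-1).values                # [[0.30, 0.10, 0.25, 0.20]]
group_idx = torch.topk(group_scores, k=1, dim=-1, sorted=False)[1]     # only group 0 survives
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)  # [[1., 0., 0., 0.]]
score_mask = group_mask.unsqueeze(-1).expand(1, 4, 2).reshape(1, -1)   # broadcast group mask to experts
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)               # only experts 0 and 1 remain
topk_weight, topk_idx = torch.topk(tmp_scores, k=2, dim=-1)            # -> experts 1 and 0
print(topk_idx, topk_weight)   # tensor([[1, 0]]) tensor([[0.3000, 0.0500]])
```

A plain greedy top-2 would have picked experts 1 and 4; the group limit forces both picks into the winning group, which is exactly the "group limited" part of the method.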
9 changes: 3 additions & 6 deletions llmc/compression/quantization/osplus.py
@@ -22,12 +22,9 @@ def __init__(self, model, quant_config, input, padding_mask, config, modality='l
super().__init__(model, quant_config, input, padding_mask, config, modality)

@torch.no_grad()
def filter_subset(self, layers_dict, prev_op):
def filter_subset(self, prev_op):
if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
if 'mlp.experts.0.gate_proj' in list(layers_dict.keys()):
return False
else:
return True
return True
else:
return False

@@ -196,7 +193,7 @@ def subset_transform(
logger.info('Cannot apply scale. Do not transform this subset.')
return

if not self.filter_subset(layers_dict, prev_op):
if not self.filter_subset(prev_op):
logger.info('Do not transform this subset.')
return

24 changes: 19 additions & 5 deletions llmc/compression/quantization/quarot.py
@@ -51,7 +51,7 @@ def preprocess(self):
self.rotate_head(self.Q)

# for vlm model
if hasattr(self.model, 'vlm_model'):
if hasattr(self.model, 'vlm_model') and self.model.vlm_model is not None:
logger.info('For vlm model, quarot need rotate last layer in projector.')
"""
txt_input img_input
@@ -116,6 +116,19 @@ def subset_transform(self, block, subset):

layers = list(layers_dict.values())

if 'skip_rotate' in subset and subset['skip_rotate']:
return

if 'need_rotate_alone' in subset and subset['need_rotate_alone']:
assert isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_))
pre_layers = subset['pre_layers']
post_layers = subset['post_layers']
for layer in pre_layers + post_layers:
layer = layer.cuda()
self.fuse_ln_fcs(prev_op[0], layers)
self.rotate_pre_layers(pre_layers, self.Q)
self.rotate_post_layers(post_layers, self.Q)

if isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
self.fuse_ln_fcs(prev_op[0], layers)
self.rotate_pre_layers(layers, self.Q)
@@ -133,10 +146,11 @@ def subset_transform(self, block, subset):
logger.info(f'{self.Q.shape}')
self.rotate_post_layers(layers, self.Q, exact_had=False)
if self.online_rotate:
apply_exact_had_to_linear(
prev_op[0], had_dim=self.head_dim, output=True
)
apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)
if prev_op[0] is not None:
apply_exact_had_to_linear(
prev_op[0], had_dim=self.head_dim, output=True
)
apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)

@torch.no_grad()
def save_model(self, path):
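On the quarot.py changes as a whole: `skip_rotate` lets a subset opt out of rotation entirely, `need_rotate_alone` handles subsets whose `pre_layers` and `post_layers` must be fused with the preceding (RMS)Norm and rotated with Q by themselves (presumably needed for DeepSeek-V2's decomposed attention projections, given the PR title), and the `prev_op[0] is not None` guard skips the online Hadamard when there is no producer to apply it to. All of these lean on the same invariance: rotating one linear's output side and the next linear's input side with the same orthogonal Q leaves their composition unchanged. A quick numeric check of that fact, written as a sketch in the y = x·Wᵀ convention rather than as llmc code:

```python
import torch

torch.manual_seed(0)
d = 8
x = torch.randn(3, d)
w1 = torch.randn(d, d)                          # layer writing into the residual stream
w2 = torch.randn(d, d)                          # layer reading from the residual stream
q, _ = torch.linalg.qr(torch.randn(d, d))       # random orthogonal rotation

ref = (x @ w1.T) @ w2.T                         # original two-layer composition
rot = (x @ (q.T @ w1).T) @ (w2 @ q).T           # rotate w1's output side and w2's input side
print(torch.allclose(ref, rot, atol=1e-4))      # True: the rotated pair computes the same function
```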