Merge pull request #33 from ModelTC/slm
Support SLM, e.g., Phi, Qwen2, Gemma2, Internlm2.5, MiniCPM, SmolLM, …
llmc-reviewer authored Aug 20, 2024
2 parents dc1b1dc + 7098077 commit 080d43b
Showing 10 changed files with 410 additions and 5 deletions.
3 changes: 2 additions & 1 deletion llmc/compression/quantization/base_blockwise_quantization.py
@@ -748,7 +748,8 @@ def fuse_ln_fcs(self, ln, fcs):
fc.bias = torch.nn.Parameter(
torch.zeros(fc.out_features, dtype=torch.float64)
)
fc.bias.data = fc.bias.data.double() + torch.matmul(W, ln.bias.double())
fc.bias.data = fc.bias.data.double().to(device=W.device) \
+ torch.matmul(W, ln.bias.double())
fc.bias.data = fc.bias.data.to(fc_dtype)

def remove_mean_from_embed(self):
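
The hunk above folds a normalization layer's bias into the bias of the following linear layer; the fix moves the (possibly freshly created, CPU-resident) bias onto the same device as the weight matrix before the addition. A minimal sketch of that folding step, written outside llmc's class structure (the float64 round-trip mirrors the diff; the standalone function name is illustrative):

import torch
import torch.nn as nn

def fuse_ln_bias_into_fc(ln: nn.LayerNorm, fc: nn.Linear) -> None:
    """Fold ln.bias into fc.bias so ln can later be replaced by a bias-free norm."""
    fc_dtype = fc.weight.dtype
    W = fc.weight.data.double()  # [out_features, in_features]
    if fc.bias is None:
        fc.bias = torch.nn.Parameter(
            torch.zeros(fc.out_features, dtype=torch.float64)
        )
    # The freshly created bias may live on the CPU while W sits on a GPU,
    # so align devices before adding W @ ln.bias (this is what the fix does).
    fc.bias.data = fc.bias.data.double().to(device=W.device) \
        + torch.matmul(W, ln.bias.double())
    fc.bias.data = fc.bias.data.to(fc_dtype)
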
39 changes: 37 additions & 2 deletions llmc/compression/quantization/module_utils.py
@@ -9,6 +9,17 @@
from transformers.models.mistral.modeling_mistral import MistralRMSNorm
from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

try:
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
except Exception:
logger.info(
'Gemma2RMSNorm not installed. '
'If you need it, please update your transformers lib.'
)

class Gemma2RMSNorm(nn.Module):
pass
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

try:
@@ -17,7 +28,7 @@
from .hadamard_utils import matmul_hadU_cuda
except Exception:
logger.info(
'fast_hadamard_transform not installed.'
'fast_hadamard_transform not installed. '
'If you need it, please install it firstly.'
)

@@ -122,7 +133,10 @@ def forward(self, hidden_states):
@classmethod
@torch.no_grad()
def new(cls, module):
eps = module.variance_epsilon
if hasattr(module, 'eps'):
eps = module.eps
else:
eps = module.variance_epsilon
weight = module.weight
new_module = cls(weight, eps)
return new_module
@@ -163,6 +177,22 @@ def __repr__(self):
return 'LlmcInternLM2RMSNorm()'


class LlmcGemma2RMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
super().__init__(weight, eps)

def __repr__(self):
return 'LlmcGemma2RMSNorm()'


class LlmcMiniCPMRMSNorm(LlmcLlamaRMSNorm):
def __init__(self, weight, eps=1e-6):
super().__init__(weight, eps)

def __repr__(self):
return 'LlmcMiniCPMRMSNorm()'


class OriginFloatLinear(nn.Module):
def __init__(self, weight, bias, ori_module):
super().__init__()
@@ -616,6 +646,7 @@ def __repr__(self):
MixtralRMSNorm,
Qwen2RMSNorm,
LlamaRMSNorm,
Gemma2RMSNorm,
nn.LayerNorm,
]
_TRANSFORMERS_LINEAR_TYPES_ = [nn.Linear]
@@ -627,6 +658,8 @@ def __repr__(self):
'Mixtral': LlmcMixtralRMSNorm,
'Interlm2': LlmcInternLM2RMSNorm,
'Qwen2': LlmcQwen2RMSNorm,
'Gemma2': LlmcGemma2RMSNorm,
'MiniCPM': LlmcMiniCPMRMSNorm,
'Starcoder': LlmcLayerNorm,
'Opt': LlmcLayerNorm,
'Bloom': LlmcLayerNorm,
@@ -641,6 +674,8 @@ def __repr__(self):
LlmcMistralRMSNorm,
LlmcMixtralRMSNorm,
LlmcInternLM2RMSNorm,
LlmcGemma2RMSNorm,
LlmcMiniCPMRMSNorm,
]


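Two patterns in this file are worth calling out: an import guard that substitutes a stub class when the installed transformers version lacks a norm type, and an attribute probe in the new() factory, since some Hugging Face RMSNorm implementations expose the epsilon as eps while others use variance_epsilon. A compressed, self-contained sketch of both, assuming a LLaMA-style scale-by-weight convention (the class name LlmcRMSNorm and the forward shown here are illustrative, not llmc's exact implementation):

import torch
import torch.nn as nn

try:
    # Only present in sufficiently recent transformers releases.
    from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
except Exception:
    class Gemma2RMSNorm(nn.Module):  # stub so isinstance checks still work
        pass

class LlmcRMSNorm(nn.Module):
    def __init__(self, weight, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(weight.clone())
        self.eps = eps

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x).to(dtype)

    @classmethod
    @torch.no_grad()
    def new(cls, module):
        # HF norms disagree on the attribute name for the epsilon.
        eps = module.eps if hasattr(module, 'eps') else module.variance_epsilon
        return cls(module.weight, eps)
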
19 changes: 17 additions & 2 deletions llmc/compression/quantization/quarot.py
@@ -21,7 +21,9 @@ def __init__(self, model, quant_config, input, config):
self.preprocess()

def preprocess(self):
assert self.config['model']['type'] in ['Opt', 'Llama', 'Qwen2']
assert self.config['model']['type'] in [
'Opt', 'Llama', 'Qwen2', 'InternLM2',
'MiniCPM', 'StableLm', 'SmolLM']
# if self.config["model"]["type"] in ["Opt"]:
if torch.equal(
self.model.get_head_layers()[0].weight,
@@ -83,6 +85,16 @@ def block_transform(self, block):
logger.info(f'block:{block}')
logger.info(f'End transform the {self.block_idx+1}-th block')

def bake_mean_into_linear(self, linear):
linear_dtype = linear.weight.dtype
W_ = linear.weight.data.double()
linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
linear.weight.data = linear.weight.data.to(linear_dtype)
if linear.bias is not None:
b_ = linear.bias.data.double()
linear.bias.data = b_ - b_.mean()
linear.bias.data = linear.bias.data.to(linear_dtype)

@torch.no_grad()
def subset_transform(self, block, subset):
prev_op = subset['prev_op']
@@ -97,14 +109,17 @@ def subset_transform(self, block, subset):
self.fuse_ln_fcs(prev_op[0], layers)
self.rotate_pre_layers(layers, self.Q)
else:
if self.config['model']['type'] in ['Opt']:
if self.config['model']['type'] in ['Opt', 'StableLm']:
self.bake_mean_into_linear(layers[0])

if 'is_mlp' in subset and subset['is_mlp']:
self.rotate_post_layers(
layers, self.Q, exact_had=True if self.online_rotate else False
)
else:
for n, m in layers_dict.items():
logger.info(f'layer: {n} {m.weight.shape}')
logger.info(f'{self.Q.shape}')
self.rotate_post_layers(layers, self.Q, exact_had=False)
if self.online_rotate:
apply_exact_had_to_linear(
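
The new bake_mean_into_linear helper subtracts, for each input column of the weight, its mean over the output dimension (and subtracts the mean of the bias). Algebraically this is the same as mean-centering the layer's output, i.e. it bakes the mean subtraction normally done by a following LayerNorm into the linear layer itself, which is why subset_transform applies it for the Opt and StableLm model types. A quick self-contained check of that equivalence (sizes, seed, and tolerance are arbitrary):

import torch
import torch.nn as nn

def bake_mean_into_linear(linear: nn.Linear) -> None:
    # Same computation as the helper added above.
    dtype = linear.weight.dtype
    W = linear.weight.data.double()
    linear.weight.data = (W - W.mean(dim=-2, keepdim=True)).to(dtype)
    if linear.bias is not None:
        b = linear.bias.data.double()
        linear.bias.data = (b - b.mean()).to(dtype)

torch.manual_seed(0)
lin = nn.Linear(8, 8, dtype=torch.float64)
x = torch.randn(4, 8, dtype=torch.float64)
y = lin(x)
bake_mean_into_linear(lin)
# The baked layer now returns mean-centered outputs.
assert torch.allclose(lin(x), y - y.mean(dim=-1, keepdim=True), atol=1e-10)
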
4 changes: 4 additions & 0 deletions llmc/models/__init__.py
@@ -4,8 +4,12 @@
from .internlm2 import InternLM2
from .llama import Llama
from .llava import Llava
from .minicpm import MiniCPM
from .mistral import Mistral
from .mixtral import Mixtral
from .opt import Opt
from .phi import Phi
from .qwen2 import Qwen2
from .smollm import SmolLM
from .stablelm import StableLm
from .starcoder import Starcoder
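
These imports matter beyond namespacing: each model class registers itself with the @MODEL_REGISTRY decorator at import time, so a model that is not imported here cannot be selected from a config. A minimal sketch of that decorator-based registry pattern (the real implementation lives in llmc/utils/registry_factory.py and may differ in details; names below are illustrative):

class Registry:
    def __init__(self):
        self._classes = {}

    def __call__(self, cls):
        # Usable directly as a class decorator: @MODEL_REGISTRY
        self._classes[cls.__name__] = cls
        return cls

    def __getitem__(self, name):
        return self._classes[name]

MODEL_REGISTRY = Registry()

@MODEL_REGISTRY
class SmolLM:  # stand-in for the real model wrapper class
    pass

# A config entry such as model.type == 'SmolLM' then resolves to the class:
model_cls = MODEL_REGISTRY['SmolLM']
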
30 changes: 30 additions & 0 deletions llmc/models/gemma2.py
@@ -1,12 +1,34 @@
from loguru import logger

from llmc.utils.registry_factory import MODEL_REGISTRY

try:
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
except Exception:
logger.warning('Gemma2 not found')
from types import MethodType

import torch.nn as nn

from .base_model import BaseModel


def gemma2_rms_norm_forward(self, x):
output = self._norm(x.float())
output = output * self.weight.float()
return output.type_as(x)


@MODEL_REGISTRY
class Gemma2(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)
for m in self.model.modules():
if isinstance(m, Gemma2RMSNorm):
w = m.weight.data
del m.weight
m.weight = nn.Parameter(w + 1.0)
m.forward = MethodType(gemma2_rms_norm_forward, m)

def find_blocks(self):
self.blocks = self.model.model.layers
@@ -21,6 +43,12 @@ def find_block_name(self):
def get_embed_layers(self):
return [self.embed_tokens]

def get_head_layers(self):
return [self.model.lm_head]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

@@ -62,12 +90,14 @@ def get_subsets_in_block(self, block):
'input': ['mlp.gate_proj'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.down_proj': block.mlp.down_proj},
'prev_op': [block.mlp.up_proj],
'input': ['mlp.down_proj'],
'inspect': block.mlp.down_proj,
'has_kwargs': False,
'is_mlp': True,
},
]
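
The Gemma2 constructor rewrites every Gemma2RMSNorm in place: Hugging Face's Gemma norms scale the normalized activations by (1 + weight), while llmc's Llama-style replacement norms scale by the stored weight directly, so the wrapper stores weight + 1.0 and swaps in a forward that drops the "+ 1". A small check of that re-parameterization, assuming the upstream (1 + weight) convention:

import torch

def rms_norm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(2, 5)
w = torch.randn(5)

# Upstream Gemma2 convention: scale by (1 + w).
y_gemma = rms_norm(x.float()) * (1.0 + w.float())

# llmc convention after the rewrite: store w + 1 and scale by it directly,
# which is what gemma2_rms_norm_forward does with the converted weight.
y_llmc = rms_norm(x.float()) * (w + 1.0).float()

assert torch.allclose(y_gemma, y_llmc)
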
11 changes: 11 additions & 0 deletions llmc/models/internlm2.py
@@ -1,3 +1,4 @@
from llmc.compression.quantization.module_utils import _TRANSFORMERS_LN_TYPES_
from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel
@@ -7,6 +8,8 @@
class InternLM2(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)
global _TRANSFORMERS_LN_TYPES_
_TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)]

def find_blocks(self):
self.blocks = self.model.model.layers
@@ -20,6 +23,12 @@ def find_block_name(self):
def get_embed_layers(self):
return [self.tok_embeddings]

def get_head_layers(self):
return [self.model.output]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
return [self.tok_embeddings, self.model.model.norm, self.model.output]

@@ -57,12 +66,14 @@ def get_subsets_in_block(self, block):
'input': ['feed_forward.w1'],
'inspect': block.feed_forward,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'feed_forward.w2': block.feed_forward.w2},
'prev_op': [block.feed_forward.w3],
'input': ['feed_forward.w2'],
'inspect': block.feed_forward.w2,
'has_kwargs': False,
'is_mlp': True,
},
]
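
InternLM2 and MiniCPM are typically loaded with trust_remote_code, so their RMSNorm classes are not statically importable in module_utils.py. Instead, each wrapper appends the concrete norm type found in the loaded model to the shared _TRANSFORMERS_LN_TYPES_ list at construction time; because += extends the list in place, the change is visible to module_utils as well. A sketch of the idea with stand-in names:

import torch
import torch.nn as nn

# Stand-in for the statically known norm types in module_utils.py.
_TRANSFORMERS_LN_TYPES_ = [nn.LayerNorm]

class RemoteRMSNorm(nn.Module):
    """Pretend this class came from trust_remote_code and is not importable here."""
    def __init__(self, hidden, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden))
        self.eps = eps

class InternLM2Like:
    def __init__(self):
        self.norm = RemoteRMSNorm(16)
        # Register the runtime type so later isinstance checks (used when
        # swapping norms for their Llmc* replacements) recognize it.
        global _TRANSFORMERS_LN_TYPES_
        _TRANSFORMERS_LN_TYPES_ += [type(self.norm)]

InternLM2Like()
assert any(isinstance(RemoteRMSNorm(16), t) for t in _TRANSFORMERS_LN_TYPES_)
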
84 changes: 84 additions & 0 deletions llmc/models/minicpm.py
@@ -0,0 +1,84 @@
from llmc.compression.quantization.module_utils import _TRANSFORMERS_LN_TYPES_
from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class MiniCPM(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)
global _TRANSFORMERS_LN_TYPES_
_TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)]

def find_blocks(self):
self.blocks = self.model.model.layers

def find_embed_layers(self):
self.embed_tokens = self.model.model.embed_tokens

def find_block_name(self):
self.block_name_prefix = 'model.layers'
self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}

def get_embed_layers(self):
return [self.embed_tokens]

def get_head_layers(self):
return [self.model.lm_head]

def get_pre_head_layernorm_layers(self):
return [self.model.model.norm]

def get_layers_except_blocks(self):
return [self.embed_tokens, self.model.model.norm, self.model.lm_head]

def has_bias(self):
return False

def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
}

def get_subsets_in_block(self, block):
return [
{
'layers': {
'self_attn.q_proj': block.self_attn.q_proj,
'self_attn.k_proj': block.self_attn.k_proj,
'self_attn.v_proj': block.self_attn.v_proj,
},
'prev_op': [block.input_layernorm],
'input': ['self_attn.q_proj'],
'inspect': block.self_attn,
'has_kwargs': True,
},
{
'layers': {'self_attn.o_proj': block.self_attn.o_proj},
'prev_op': [block.self_attn.v_proj],
'input': ['self_attn.o_proj'],
'inspect': block.self_attn.o_proj,
'has_kwargs': False,
},
{
'layers': {
'mlp.gate_proj': block.mlp.gate_proj,
'mlp.up_proj': block.mlp.up_proj,
},
'prev_op': [block.post_attention_layernorm],
'input': ['mlp.gate_proj'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.down_proj': block.mlp.down_proj},
'prev_op': [block.mlp.up_proj],
'input': ['mlp.down_proj'],
'inspect': block.mlp.down_proj,
'has_kwargs': False,
'is_mlp': True,
},
]
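
Each dict returned by get_subsets_in_block describes one group of linear layers that is calibrated and transformed together. The keys, roughly: 'layers' maps dotted names to modules, 'prev_op' is the preceding norm or linear whose output feeds them, 'input' and 'inspect' tell the calibration pass where to hook inputs and which module to run, 'has_kwargs' marks subsets needing extra forward kwargs (e.g. attention masks), and the new 'is_mlp' flag lets QuaRot rotate MLP subsets differently (see the quarot.py hunk above). A rough sketch of how a blockwise pass might walk these subsets (illustrative only, not llmc's actual loop):

def walk_subsets(block, get_subsets_in_block):
    """Print how each subset of a decoder block would be routed."""
    for subset in get_subsets_in_block(block):
        names = list(subset['layers'].keys())
        prev = [type(p).__name__ for p in subset['prev_op']]
        # MLP subsets carry the 'is_mlp' flag so a method such as QuaRot can
        # choose a different rotation path for them.
        kind = 'mlp' if subset.get('is_mlp', False) else 'attn/other'
        print(kind, names, 'prev_op:', prev)

On a MiniCPM block this would report the q/k/v subset and the o_proj subset without the flag, and the two MLP subsets with it.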