support some vlm models
helloyongyang committed Aug 25, 2024
1 parent 5daf093 commit aed2595
Showing 9 changed files with 160 additions and 13 deletions.
@@ -34,4 +34,3 @@ quant:
 save:
     save_trans: False
     save_path: ./save
-    tokenizer_file_substring: ["token"]
1 change: 0 additions & 1 deletion configs/quantization/Awq/awq_w_only_mix_bits_1.yml
@@ -43,4 +43,3 @@ quant:
 save:
     save_trans: False
     save_path: ./save
-    tokenizer_file_substring: ["token"]
1 change: 0 additions & 1 deletion configs/quantization/Awq/awq_w_only_mix_bits_2.yml
@@ -46,4 +46,3 @@ quant:
 save:
     save_trans: False
     save_path: ./save
-    tokenizer_file_substring: ["token"]
11 changes: 6 additions & 5 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -365,7 +365,7 @@ def apply_shift(self, shifts, prev_op, layers):
     def scale_fc_fc(self, fc1, fc2, scales):
         scales = scales.to(fc1.weight.device)
         if fc1.out_features == fc2.in_features * 3:
-            num_heads = self.model.get_model_config().to_dict().get('n_head', None)
+            num_heads = self.model.get_num_attention_heads()
             fc1.weight.t_()
             org_shape = fc1.weight.shape
             fc1.weight.data = fc1.weight.data.reshape(org_shape[0] * num_heads, 3, -1)
@@ -798,7 +798,8 @@ def deploy(self, quant_format):

     @torch.no_grad()
     def copy_tokenizer(self, path):
-        for substring in self.config.save.get('tokenizer_file_substring', ['token']):
+        for substring in self.config.save.get('tokenizer_file_substring',
+                                               ['token', 'merges', 'vocab']):
             copy_files(self.config.model.path, path, substring)
         logger.info('copy tokenizer done --')

@@ -818,9 +819,9 @@ def save_model(self, path):
             return
         if self.online_rotate:
            self.contiguous_params()
-        if self.config.model.type == 'Llava':
-            self.model.llava_model.language_model = self.model.get_model()
-            self.model.llava_model.save_pretrained(path)
+        if self.config.model.type in ['Llava', 'InternVL2']:
+            self.model.vlm_model.language_model = self.model.get_model()
+            self.model.vlm_model.save_pretrained(path)
         logger.info('save model done --')
         self.copy_tokenizer(path)
         copy_files(self.config.model.path, path, 'preprocessor_config')
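The widened default in copy_tokenizer matters because tokenizer files are selected purely by filename substring, which is also why the explicit tokenizer_file_substring entries could be dropped from the Awq configs above. The sketch below shows the kind of helper this implies; it is an assumption for illustration, not the actual copy_files in llmc.utils, and copy_files_by_substring is a hypothetical name.

import shutil
from pathlib import Path


def copy_files_by_substring(src_dir, dst_dir, substring):
    """Copy every regular file in src_dir whose name contains `substring` into dst_dir."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for f in Path(src_dir).iterdir():
        if f.is_file() and substring in f.name:
            shutil.copy(f, dst / f.name)


# With the new defaults ['token', 'merges', 'vocab'], files such as tokenizer.json,
# tokenizer_config.json, merges.txt and vocab.json would all be picked up.
for substring in ['token', 'merges', 'vocab']:
    copy_files_by_substring('/path/to/model', './save', substring)  # hypothetical paths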
3 changes: 3 additions & 0 deletions llmc/models/__init__.py
@@ -2,14 +2,17 @@
 from .falcon import Falcon
 from .gemma2 import Gemma2
 from .internlm2 import InternLM2
+from .internvl2 import InternVL2
 from .llama import Llama
 from .llava import Llava
 from .minicpm import MiniCPM
 from .mistral import Mistral
 from .mixtral import Mixtral
 from .opt import Opt
 from .phi import Phi
+from .qwen import Qwen
 from .qwen2 import Qwen2
+from .qwenvl import QwenVL
 from .smollm import SmolLM
 from .stablelm import StableLm
 from .starcoder import Starcoder
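The three new imports are what make the classes visible: importing each module runs its @MODEL_REGISTRY decorator, presumably so a model can then be selected by the type string from the config (e.g. 'InternVL2', as checked in save_model above). A minimal sketch of that decorator idiom follows; it is an assumption about the general pattern, not the actual Registry in llmc.utils.registry_factory.

class Registry:
    """Toy name-to-class registry used as a class decorator."""

    def __init__(self, name):
        self.name = name
        self._classes = {}

    def __call__(self, cls):
        # @MODEL_REGISTRY on a class records it under its own name and returns it unchanged.
        self._classes[cls.__name__] = cls
        return cls

    def __getitem__(self, key):
        return self._classes[key]


MODEL_REGISTRY = Registry('models')


@MODEL_REGISTRY
class InternVL2:
    """Stand-in for the real class in llmc/models/internvl2.py."""


assert MODEL_REGISTRY['InternVL2'] is InternVL2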
29 changes: 29 additions & 0 deletions llmc/models/internvl2.py
@@ -0,0 +1,29 @@
from loguru import logger
from transformers import AutoConfig, AutoModelForCausalLM

from llmc.utils.registry_factory import MODEL_REGISTRY

from .internlm2 import InternLM2


@MODEL_REGISTRY
class InternVL2(InternLM2):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)

def build_model(self):
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
if hasattr(self.vlm_model_config, 'use_cache'):
self.vlm_model_config.use_cache = False
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
self.vlm_model = AutoModelForCausalLM.from_pretrained(
self.model_path,
config=self.vlm_model_config,
trust_remote_code=True,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
)
self.model = self.vlm_model.language_model
self.model_config = self.vlm_model_config.llm_config
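A hypothetical usage sketch of the new wrapper: the checkpoint path is a placeholder, and the explicit build_model() call is an assumption, since the constructor behaviour inherited from InternLM2/BaseModel is not shown in this diff.

import torch
from llmc.models import InternVL2

model = InternVL2('/path/to/internvl2-checkpoint', torch.float16)  # hypothetical path
model.build_model()
print(type(model.vlm_model))   # full VLM (vision tower + language model)
print(type(model.model))       # vlm_model.language_model, the part the quantizer works on
print(model.model_config)      # vlm_model_config.llm_config

The point of the split is that the quantization passes see only self.model and self.model_config, while save_model (see base_blockwise_quantization.py above) writes the quantized language model back into vlm_model before calling save_pretrained.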
12 changes: 7 additions & 5 deletions llmc/models/llava.py
@@ -20,14 +20,16 @@ def __init__(self, model_path, torch_dtype):
         super().__init__(model_path, torch_dtype)

     def build_model(self):
-        self.model_config = AutoConfig.from_pretrained(
+        self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-        self.model_config.text_config.use_cache = False
-        self.llava_model = LlavaForConditionalGeneration.from_pretrained(
+        self.vlm_model_config.text_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
             self.model_path,
-            config=self.model_config,
+            config=self.vlm_model_config,
             torch_dtype=self.torch_dtype,
             low_cpu_mem_usage=True,
         )
-        self.model = self.llava_model.language_model
+        self.model = self.vlm_model.language_model
+        self.model_config = self.vlm_model_config.text_config
85 changes: 85 additions & 0 deletions llmc/models/qwen.py
@@ -0,0 +1,85 @@
from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class Qwen(BaseModel):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)

def find_blocks(self):
self.blocks = self.model.transformer.h

def find_embed_layers(self):
self.wte = self.model.transformer.wte
self.rotary_emb = self.model.transformer.rotary_emb

def find_block_name(self):
self.block_name_prefix = 'transformer.h'

def get_embed_layers(self):
return [self.wte, self.rotary_emb]

def get_head_layers(self):
return [self.model.lm_head]

def get_pre_head_layernorm_layers(self):
return [self.model.transformer.ln_f]

def get_layers_except_blocks(self):
return [self.wte,
self.rotary_emb,
self.model.transformer.ln_f,
self.model.lm_head]

def has_bias(self):
return False

def get_layernorms_in_block(self, block):
return {
'ln_1': block.ln_1,
'ln_2': block.ln_2,
}

def get_num_attention_heads(self):
return self.model_config.num_attention_heads

def get_subsets_in_block(self, block):
return [
{
'layers': {
'attn.c_attn': block.attn.c_attn
},
'prev_op': [block.ln_1],
'input': ['attn.c_attn'],
'inspect': block.attn,
'has_kwargs': True,
},
{
'layers': {'attn.c_proj': block.attn.c_proj},
'prev_op': [block.attn.c_attn],
'input': ['attn.c_proj'],
'inspect': block.attn.c_proj,
'has_kwargs': False,
},
{
'layers': {
'mlp.w1': block.mlp.w1,
'mlp.w2': block.mlp.w2,
},
'prev_op': [block.ln_2],
'input': ['mlp.w1'],
'inspect': block.mlp,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.c_proj': block.mlp.c_proj},
'prev_op': [block.mlp.w1],
'input': ['mlp.c_proj'],
'inspect': block.mlp.c_proj,
'has_kwargs': False,
'is_mlp': True,
},
]
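A hedged sketch of how the metadata above could be consumed; it assumes build_model() is inherited from BaseModel (not shown in this commit) and uses a placeholder checkpoint path.

import torch
from llmc.models import Qwen

qwen = Qwen('/path/to/qwen-checkpoint', torch.float16)  # hypothetical path
qwen.build_model()
qwen.find_blocks()
for subset in qwen.get_subsets_in_block(qwen.blocks[0]):
    # Each subset groups linear layers behind a shared preceding op ('prev_op'),
    # which is the hook point for scale/shift transforms in the blockwise quantizer.
    print(list(subset['layers'].keys()), '<-', [type(m).__name__ for m in subset['prev_op']])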
30 changes: 30 additions & 0 deletions llmc/models/qwenvl.py
@@ -0,0 +1,30 @@
from loguru import logger
from transformers import AutoConfig, AutoModelForCausalLM

from llmc.utils.registry_factory import MODEL_REGISTRY

from .qwen import Qwen


@MODEL_REGISTRY
class QwenVL(Qwen):
def __init__(self, model_path, torch_dtype):
super().__init__(model_path, torch_dtype)

def build_model(self):
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
if hasattr(self.vlm_model_config, 'use_cache'):
self.vlm_model_config.use_cache = False
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
self.vlm_model = AutoModelForCausalLM.from_pretrained(
self.model_path,
config=self.vlm_model_config,
trust_remote_code=True,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
)
self.model = self.vlm_model
self.model_config = self.vlm_model_config
self.vision_model = self.vlm_model.transformer.visual
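A hypothetical usage sketch, with a placeholder path: unlike Llava and InternVL2, here self.model is the whole VLM rather than a nested language_model attribute, and the vision tower is exposed separately as self.vision_model.

import torch
from llmc.models import QwenVL

qwenvl = QwenVL('/path/to/qwen-vl-checkpoint', torch.float16)  # hypothetical path
qwenvl.build_model()
print(qwenvl.model is qwenvl.vlm_model)   # True: the language model is not split out for QwenVL
print(type(qwenvl.vision_model))          # vlm_model.transformer.visual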
