Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support chatglm4v model and llava support img-txt txt calib data when… #222

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llmc/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .deepseekv2 import DeepseekV2
from .falcon import Falcon
from .gemma2 import Gemma2
from .glm4v import GLM4V
from .internlm2 import InternLM2
from .internomni import InternOmni
from .internvl2 import InternVL2
Expand Down
70 changes: 70 additions & 0 deletions llmc/models/glm4v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from loguru import logger
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM

from llmc.utils.registry_factory import MODEL_REGISTRY

from .chatglm import ChatGLM


@MODEL_REGISTRY
class GLM4V(ChatGLM):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def build_model(self):
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
if not self.use_cache:
self.vlm_model_config.use_cache = False
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
self.vlm_model = AutoModelForCausalLM.from_pretrained(
self.model_path,
config=self.vlm_model_config,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
)
self.vision_model = self.vlm_model.transformer.vision
self.projector = self.vlm_model.transformer.vision.linear_proj
self.model = self.vlm_model
self.model_config = self.vlm_model_config

def batch_process(self, img_qas, calib_or_eval='eval'):
assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
messages = []
answers = []
for idx in range(len(img_qas)):
img_path = img_qas[idx]['img']
if img_path is not None:
image = Image.open(img_path).convert('RGB')
message = [
{
'role': 'user',
'image': image,
'content': img_qas[idx]['question'],
}
]
else:
message = [{'role': 'user', 'content': img_qas[idx]['question']}]
messages.append(message)
answers.append(img_qas[idx]['answer'])
inputs = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors='pt',
return_dict=True,
padding=True,
)
if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
raise Exception(
'glm4v not support add_answer. '
'Maybe you can modify tokenization_chatglm.py in model path.'
)
if calib_or_eval == 'calib':
logger.info(f'Calib data is:\n{inputs}')

inputs = inputs.to(next(self.vlm_model.parameters()).dtype)
return inputs
34 changes: 22 additions & 12 deletions llmc/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,28 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
answers = []
for idx in range(len(img_qas)):
img_path = img_qas[idx]['img']
image = Image.open(img_path)
message = [
{
'role': 'user',
'content': [
{'type': 'image'},
{'type': 'text', 'text': img_qas[idx]['question']}
]
}
]
if img_path is not None:
image = Image.open(img_path)
message = [
{
'role': 'user',
'content': [
{'type': 'image'},
{'type': 'text', 'text': img_qas[idx]['question']}
]
}
]
images.append(image)
else:
message = [
{
'role': 'user',
'content': [
{'type': 'text', 'text': img_qas[idx]['question']}
]
}
]
messages.append(message)
images.append(image)
answers.append(img_qas[idx]['answer'])
texts = [
self.processor.apply_chat_template(messages[n], add_generation_prompt=True)
Expand All @@ -67,7 +77,7 @@ def batch_process(self, img_qas, calib_or_eval='eval'):

inputs = self.processor(
text=texts,
images=images,
images=images if len(images) else None,
padding=True,
return_tensors='pt'
).to(next(self.vlm_model.parameters()).dtype) # noqa
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ more_itertools
qtorch
einops
qwen-vl-utils
tiktoken
Loading