From 7b5e918d7d59382dcc7d62aa93eafc6d11c27e43 Mon Sep 17 00:00:00 2001
From: Yang Yong
Date: Fri, 22 Nov 2024 17:02:14 +0800
Subject: [PATCH] support chatglm4v model and llava support img-txt txt calib
 data when bs=1 (#222)

---
 llmc/models/__init__.py  |  1 +
 llmc/models/glm4v.py     | 70 ++++++++++++++++++++++++++++++++++++++++
 llmc/models/llava.py     | 34 ++++++++++++-------
 requirements/runtime.txt |  1 +
 4 files changed, 94 insertions(+), 12 deletions(-)
 create mode 100644 llmc/models/glm4v.py

diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py
index e3da94a0..54fbd165 100644
--- a/llmc/models/__init__.py
+++ b/llmc/models/__init__.py
@@ -3,6 +3,7 @@
 from .deepseekv2 import DeepseekV2
 from .falcon import Falcon
 from .gemma2 import Gemma2
+from .glm4v import GLM4V
 from .internlm2 import InternLM2
 from .internomni import InternOmni
 from .internvl2 import InternVL2
diff --git a/llmc/models/glm4v.py b/llmc/models/glm4v.py
new file mode 100644
index 00000000..dae45bc1
--- /dev/null
+++ b/llmc/models/glm4v.py
@@ -0,0 +1,70 @@
+from loguru import logger
+from PIL import Image
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .chatglm import ChatGLM
+
+
+@MODEL_REGISTRY
+class GLM4V(ChatGLM):
+    def __init__(self, config, device_map=None, use_cache=False):
+        super().__init__(config, device_map, use_cache)
+
+    def build_model(self):
+        self.vlm_model_config = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        if not self.use_cache:
+            self.vlm_model_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            config=self.vlm_model_config,
+            torch_dtype=self.torch_dtype,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+        )
+        self.vision_model = self.vlm_model.transformer.vision
+        self.projector = self.vlm_model.transformer.vision.linear_proj
+        self.model = self.vlm_model
+        self.model_config = self.vlm_model_config
+
+    def batch_process(self, img_qas, calib_or_eval='eval'):
+        assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
+        messages = []
+        answers = []
+        for idx in range(len(img_qas)):
+            img_path = img_qas[idx]['img']
+            if img_path is not None:
+                image = Image.open(img_path).convert('RGB')
+                message = [
+                    {
+                        'role': 'user',
+                        'image': image,
+                        'content': img_qas[idx]['question'],
+                    }
+                ]
+            else:
+                message = [{'role': 'user', 'content': img_qas[idx]['question']}]
+            messages.append(message)
+            answers.append(img_qas[idx]['answer'])
+        inputs = self.tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors='pt',
+            return_dict=True,
+            padding=True,
+        )
+        if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
+            raise Exception(
+                'glm4v not support add_answer. '
+                'Maybe you can modify tokenization_chatglm.py in model path.'
+            )
+        if calib_or_eval == 'calib':
+            logger.info(f'Calib data is:\n{inputs}')
+
+        inputs = inputs.to(next(self.vlm_model.parameters()).dtype)
+        return inputs
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index 7b78c619..b25ec0c9 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -40,18 +40,28 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
         answers = []
         for idx in range(len(img_qas)):
             img_path = img_qas[idx]['img']
-            image = Image.open(img_path)
-            message = [
-                {
-                    'role': 'user',
-                    'content': [
-                        {'type': 'image'},
-                        {'type': 'text', 'text': img_qas[idx]['question']}
-                    ]
-                }
-            ]
+            if img_path is not None:
+                image = Image.open(img_path)
+                message = [
+                    {
+                        'role': 'user',
+                        'content': [
+                            {'type': 'image'},
+                            {'type': 'text', 'text': img_qas[idx]['question']}
+                        ]
+                    }
+                ]
+                images.append(image)
+            else:
+                message = [
+                    {
+                        'role': 'user',
+                        'content': [
+                            {'type': 'text', 'text': img_qas[idx]['question']}
+                        ]
+                    }
+                ]
             messages.append(message)
-            images.append(image)
             answers.append(img_qas[idx]['answer'])
         texts = [
             self.processor.apply_chat_template(messages[n], add_generation_prompt=True)
@@ -67,7 +77,7 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
 
         inputs = self.processor(
             text=texts,
-            images=images,
+            images=images if len(images) else None,
             padding=True,
             return_tensors='pt'
         ).to(next(self.vlm_model.parameters()).dtype)  # noqa
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 4d53ebb4..ddf31179 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -28,3 +28,4 @@ more_itertools
 qtorch
 einops
 qwen-vl-utils
+tiktoken
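
Note (not part of the patch): below is a minimal sketch of the mixed image-text / text-only calibration data that the updated batch_process() methods are meant to accept at batch size 1. The 'img' / 'question' / 'answer' keys follow the code above; the file paths, example strings, and the commented-out model construction are illustrative assumptions only, not part of the repository.

# Illustrative sketch only; not part of the patch.
# The img_qas structure mirrors batch_process() above; paths and text are made up,
# and the GLM4V(config) construction is a placeholder for the usual llmc setup.
from llmc.models import GLM4V  # registered via the __init__.py change above

img_qas = [
    {
        # image-text sample: 'img' points to an image file
        'img': '/path/to/example.jpg',
        'question': 'What is shown in the image?',
        'answer': 'A cat sitting on a sofa.',
    },
    {
        # text-only sample: 'img' is None, so no image is attached to the message
        'img': None,
        'question': 'Briefly describe what calibration data is used for.',
        'answer': 'It provides activation statistics for quantization.',
    },
]

# Hypothetical usage with batch size 1: each call receives one sample, so a
# batch is either purely image-text or purely text, which is what the new
# branches in glm4v.py and llava.py handle.
# model = GLM4V(config)  # config would come from the llmc YAML setup
# model.build_model()
# inputs = model.batch_process(img_qas[:1], calib_or_eval='calib')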