Skip to content

Commit

Permalink
update internvl2 (#193)
Browse files — browse the repository at this point in the history
  • Loading branch information
helloyongyang authored Nov 17, 2024
1 parent 0ea65a8 commit 14bb456
Showing 1 changed file with 16 additions and 75 deletions.
91 changes: 16 additions & 75 deletions llmc/models/internvl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,25 @@ def build_model(self):
self.vlm_model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

def batch_process(self, img_qas):
if len(img_qas) == 1:
return self.single_process(img_qas[0])
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
questions = []
pixel_values_list = []
num_patches_list = []
for idx in range(len(img_qas)):
img_path = img_qas[idx]['img']
pixel_values = load_image(img_path, max_num=12).to(
next(self.vlm_model.parameters()).dtype
)
pixel_values_list.append(pixel_values)
num_patches_list.append(pixel_values.size(0))
questions.append(f"<image>\n{img_qas[idx]['question']}")
pixel_values = torch.cat(pixel_values_list, dim=0)
_num_patches_i = []
if img_path is not None:
if not isinstance(img_path, list):
img_path = [img_path]
for img_idx in range(len(img_path)):
pixel_values = load_image(img_path[img_idx], max_num=12).to(
next(self.vlm_model.parameters()).dtype
)
pixel_values_list.append(pixel_values)
_num_patches_i.append(pixel_values.size(0))
num_patches_list.append(_num_patches_i)
questions.append(img_qas[idx]['question'])
pixel_values = torch.cat(pixel_values_list, dim=0) if len(pixel_values_list) > 0 else None
generation_config = dict()

IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
Expand All @@ -163,12 +167,10 @@ def batch_process(self, img_qas):
template.append_message(template.roles[0], question)
template.append_message(template.roles[1], None)
query = template.get_prompt()
image_tokens = (IMG_START_TOKEN +
IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * num_patches +
IMG_END_TOKEN)
query = query.replace('<image>', image_tokens, 1)
for _num_patches_i in num_patches:
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * _num_patches_i + IMG_END_TOKEN # noqa
query = query.replace('<image>', image_tokens, 1)
queries.append(query)

tokenizer.padding_side = 'left'
model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
input_ids = model_inputs['input_ids']
Expand All @@ -183,64 +185,3 @@ def batch_process(self, img_qas):
**generation_config
}
return inputs

def single_process(self, img_qa):
    """Build generation inputs for a single image-question pair.

    Args:
        img_qa: dict with keys 'img' (an image path, a list of image
            paths, or None for a text-only question) and 'question'
            (the prompt text).

    Returns:
        dict with 'pixel_values', 'input_ids', 'attention_mask' and
        generation settings (currently only 'eos_token_id'), ready to be
        expanded into the model's generate call.

    Raises:
        Exception: if the conversation template module cannot be loaded.
    """
    tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
    num_patches_list = None
    pixel_values_list = []
    if img_qa['img'] is not None:
        # Cast image tensors to the model's parameter dtype so they can be
        # fed to the vision tower directly (hoisted out of the loop).
        model_dtype = next(self.vlm_model.parameters()).dtype
        if isinstance(img_qa['img'], list):
            num_patches_list = []
            for img_path in img_qa['img']:
                pixel_values = load_image(img_path, max_num=12).to(model_dtype)
                pixel_values_list.append(pixel_values)
                num_patches_list.append(pixel_values.size(0))
            pixel_values = torch.cat(pixel_values_list, dim=0)
        else:
            pixel_values = load_image(img_qa['img'], max_num=12).to(model_dtype)
    else:
        pixel_values = None
    question = img_qa['question']
    # Ensure at least one <image> placeholder when an image is supplied.
    # NOTE(review): multi-image questions are assumed to already carry their
    # own <image> placeholders — confirm against callers.
    if pixel_values is not None and '<image>' not in question:
        question = '<image>\n' + question
    if num_patches_list is None:
        # Single image (or no image): derive the patch count from the tensor.
        num_patches_list = [pixel_values.size(0)] if pixel_values is not None else []
    generation_config = dict()

    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'
    try:
        template = get_conv_template(self.vlm_model.template)
    except Exception as err:
        # Chain the original failure so the real cause stays in the traceback.
        raise Exception(
            'InternLM2 conversation.py could not be found. '
            'Please copy it from model path to llmc/models.'
        ) from err
    template.system_message = self.vlm_model.system_message
    template.append_message(template.roles[0], question)
    template.append_message(template.roles[1], None)
    query = template.get_prompt()
    # Expand each <image> placeholder, in order, into the start/context/end
    # token run the model expects (num_image_token context tokens per patch).
    for num_patches in num_patches_list:
        image_tokens = (IMG_START_TOKEN +
                        IMG_CONTEXT_TOKEN * self.vlm_model.num_image_token * num_patches +
                        IMG_END_TOKEN)
        query = query.replace('<image>', image_tokens, 1)

    model_inputs = tokenizer(query, return_tensors='pt')
    input_ids = model_inputs['input_ids']
    attention_mask = model_inputs['attention_mask']
    # The template separator doubles as the end-of-turn token for generation.
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
    generation_config['eos_token_id'] = eos_token_id

    inputs = {
        'pixel_values': pixel_values,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        **generation_config
    }
    return inputs

0 comments on commit 14bb456

Please sign in to comment.