Commit v0.5.4
Tongjilibo committed Sep 28, 2024
1 parent e4bc362 commit c6db7eb
Showing 7 changed files with 42 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -111,9 +111,9 @@ pip install git+https://github.com/Tongjilibo/bert4torch
### 4.1 Version history
|Date| bert4torch | torch4keras | Release notes |
|------| ---------------- | ----------------- |----------- |
+|20240928| 0.5.4 | 0.2.7 | [New] Added the deepseek series, MiniCPM, MiniCPMV, llama3.2, Qwen2.5; support for device_map=auto; [Fix] Fixed bugs in batch_generate and with n>1|
|20240814| 0.5.3 | 0.2.6 | [New] Added llama3.1/Yi1.5; automatically selects hfmirror for downloads; added the `bert4torch-llm-server` command-line entry point|
|20240801| 0.5.2 | 0.2.5 | [New] function call support for the chatglm/qwen series, added the internlm2 series; [Tweak] simplified chat demo invocation in pipelines, generate's stop-token elements may be lists, unified the rope_scaling parameter name, added rope subclasses; [Bug] fixed a flash_attn2 inference bug and a bart tie_word_embedding bug|
|20240619| 0.5.1 | 0.2.4 | Added Qwen1.5, Qwen2, glm4; added SWA/convert_lm_logits_dtype; adjusted the trainers (notably DPOTrainer); segment_ids in generation and repetition_penalty now require the query; fixed a dtype-cast bug in RMSNorm|

[More versions](https://github.com/Tongjilibo/bert4torch/blob/master/docs/Update.md)

13 changes: 11 additions & 2 deletions bert4torch/models/minicpmv/minicpmv.py
@@ -10,6 +10,7 @@
from bert4torch.models.llama import LLaMA
from bert4torch.models.base import BERT_BASE
from bert4torch.snippets import is_transformers_available, DottableDict, inference_mode
+import inspect


if is_transformers_available():
@@ -40,6 +41,7 @@ def init_vision_module(self):
        setattr(model, 'embed_dim', model.embeddings.embed_dim)
        setattr(model, 'patch_size', model.embeddings.patch_size)

+        self.vlm_tgt_sizes = 'tgt_sizes' in inspect.signature(model).parameters
        return model

    def init_resampler(self, embed_dim, vision_dim):
@@ -105,11 +107,17 @@ def get_vllm_embedding(self, data):
            for i in range(0, B, vision_batch_size):
                start_idx = i
                end_idx = i + vision_batch_size
-                tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx]).last_hidden_state
+                inputs_ = {'patch_attention_mask': patch_attn_mask[start_idx:end_idx]}
+                if self.vlm_tgt_sizes:
+                    inputs_['tgt_sizes'] = tgt_sizes[start_idx:end_idx]
+                tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx], **inputs_).last_hidden_state
                hs.append(tmp_hs)
            vision_embedding = torch.cat(hs, dim=0)
        else:
-            vision_embedding = self.vpm(all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state
+            inputs_ = {'patch_attention_mask': patch_attn_mask}
+            if self.vlm_tgt_sizes:
+                inputs_['tgt_sizes'] = tgt_sizes
+            vision_embedding = self.vpm(all_pixel_values, **inputs_).last_hidden_state
        vision_embedding = self.resampler(vision_embedding, tgt_sizes)

        start = 0
@@ -265,6 +273,7 @@ def init_vision_module(self):

        setattr(model, 'embed_dim', model.embeddings.embed_dim)
        setattr(model, 'patch_size', model.embeddings.patch_size)
+        self.vlm_tgt_sizes = 'tgt_sizes' in inspect.signature(model).parameters

        return model

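Both `init_vision_module` variants above rely on the same idiom: probe the vision tower's signature once at init time, then pass `tgt_sizes` only when it is accepted, so one code path serves MiniCPM-V-2_6 (which takes `tgt_sizes`) and MiniCPM-Llama3-V-2_5 (which does not). A minimal, self-contained sketch of the idiom; the tower functions here are illustrative stand-ins, not code from this repo:

```python
import inspect

def tower_v26(pixel_values, patch_attention_mask=None, tgt_sizes=None):
    # Stand-in for a vision tower that accepts tgt_sizes (as in MiniCPM-V-2_6)
    return f'v2.6: tgt_sizes={tgt_sizes}'

def tower_v25(pixel_values, patch_attention_mask=None):
    # Stand-in for a vision tower without tgt_sizes (as in MiniCPM-Llama3-V-2_5)
    return 'v2.5: no tgt_sizes'

def encode(tower, pixel_values, patch_attn_mask, tgt_sizes):
    # Probe the signature once, then forward the keyword only when supported
    accepts_tgt_sizes = 'tgt_sizes' in inspect.signature(tower).parameters
    kwargs = {'patch_attention_mask': patch_attn_mask}
    if accepts_tgt_sizes:
        kwargs['tgt_sizes'] = tgt_sizes
    return tower(pixel_values, **kwargs)

print(encode(tower_v26, 'px', None, [(32, 32)]))  # v2.6: tgt_sizes=[(32, 32)]
print(encode(tower_v25, 'px', None, [(32, 32)]))  # v2.5: no tgt_sizes
```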
32 changes: 23 additions & 9 deletions bert4torch/pipelines/chatv.py
@@ -45,6 +45,7 @@
from argparse import REMAINDER, ArgumentParser
from copy import deepcopy
from PIL import Image
+import inspect


if is_fastapi_available():
@@ -197,15 +198,28 @@ def build_prompt(

            prompts_lists.append(self.processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True))
            input_images_lists.append(history_images + image)

-        inputs = self.processor(
-            prompts_lists,
-            input_images_lists,
-            max_slice_nums=kwargs.get('max_slice_nums'),
-            use_image_id=kwargs.get('use_image_id'),
-            return_tensors="pt",
-            max_length=kwargs.get('max_inp_length'),
-        ).to(self.device)
+        if 'max_slice_nums' in inspect.signature(self.processor).parameters:
+            # MiniCPM-V-2_6
+            inputs = self.processor(
+                prompts_lists,
+                input_images_lists,
+                max_slice_nums=kwargs.get('max_slice_nums'),
+                use_image_id=kwargs.get('use_image_id'),
+                return_tensors="pt",
+                max_length=kwargs.get('max_inp_length'),
+            ).to(self.device)
+        else:
+            # MiniCPM-Llama3-V-2_5 accepts only a single image per prediction
+            if len(prompts_lists) > 1:
+                raise ValueError('`MiniCPM-Llama3-V-2_5` does not support batch inference.')
+            inputs = self.processor(
+                prompts_lists[0],
+                input_images_lists[0],
+                return_tensors="pt",
+                max_length=kwargs.get('max_inp_length'),
+            ).to(self.device)
        inputs['attention_mask'] = torch.ones_like(inputs['input_ids'], dtype=bool)

        inputs.pop("image_sizes")
        return inputs
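The same probe drives the processor dispatch in `build_prompt`: only the MiniCPM-V-2_6 processor exposes `max_slice_nums`, so its presence selects the batch path, while the older processor falls back to single-sample inference. A hedged sketch with stub processors standing in for the real transformers ones (`V26Processor` and `V25Processor` are illustrative names, not repo classes):

```python
import inspect

class V26Processor:
    # Stub for a batch-capable processor that exposes max_slice_nums
    def __call__(self, prompts, images, max_slice_nums=None, max_length=None):
        return {'batch_size': len(prompts)}

class V25Processor:
    # Stub for the older single-sample processor
    def __call__(self, prompt, image, max_length=None):
        return {'batch_size': 1}

def build_inputs(processor, prompts, images):
    # inspect.signature on an instance resolves to its __call__ signature
    if 'max_slice_nums' in inspect.signature(processor).parameters:
        return processor(prompts, images, max_slice_nums=9)
    if len(prompts) > 1:
        raise ValueError('single-sample processor does not support batch inference')
    return processor(prompts[0], images[0])

print(build_inputs(V26Processor(), ['p1', 'p2'], ['i1', 'i2']))  # {'batch_size': 2}
print(build_inputs(V25Processor(), ['p1'], ['i1']))              # {'batch_size': 1}
```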
1 change: 1 addition & 0 deletions docs/History.md
@@ -1,5 +1,6 @@
## Update history

+- **20240928**: [New] Added the deepseek series, MiniCPM, MiniCPMV, llama3.2, Qwen2.5; support for device_map=auto; [Fix] Fixed bugs in batch_generate and with n>1
- **20240814**: [New] Added llama3.1/Yi1.5; automatically selects hfmirror for downloads; added the `bert4torch-llm-server` command-line entry point
- **20240801**: [New] function call support for the chatglm/qwen series, added the internlm2 series; [Tweak] simplified chat demo invocation in pipelines, generate's stop-token elements may be lists, unified the rope_scaling parameter name, added rope subclasses; [Bug] fixed a flash_attn2 inference bug and a bart tie_word_embedding bug
- **20240619**: Added Qwen1.5, Qwen2, glm4; added SWA/convert_lm_logits_dtype; adjusted the trainers (notably DPOTrainer); segment_ids in generation and repetition_penalty now require the query
1 change: 1 addition & 0 deletions docs/Update.md
@@ -2,6 +2,7 @@

|Date| bert4torch version | torch4keras version | Release notes |
|------| ---------------- | ----------------- |----------- |
+|20240928| 0.5.4 | 0.2.7 | [New] Added the deepseek series, MiniCPM, MiniCPMV, llama3.2, Qwen2.5; support for device_map=auto; [Fix] Fixed bugs in batch_generate and with n>1|
|20240814| 0.5.3 | 0.2.6 | [New] Added llama3.1/Yi1.5; automatically selects hfmirror for downloads; added the `bert4torch-llm-server` command-line entry point|
|20240801| 0.5.2 | 0.2.5 | [New] function call support for the chatglm/qwen series, added the internlm2 series; [Tweak] simplified chat demo invocation in pipelines, generate's stop-token elements may be lists, unified the rope_scaling parameter name, added rope subclasses; [Bug] fixed a flash_attn2 inference bug and a bart tie_word_embedding bug|
|20240619| 0.5.1 | 0.2.4 | Added Qwen1.5, Qwen2, glm4; added SWA/convert_lm_logits_dtype; adjusted the trainers (notably DPOTrainer); segment_ids in generation and repetition_penalty now require the query; fixed a dtype-cast bug in RMSNorm|
6 changes: 3 additions & 3 deletions examples/basic/MiniCPM/basic_language_model_llama_MiniCPMV.py
@@ -3,9 +3,9 @@
from bert4torch.snippets import log_info


-# 'E:/data/pretrain_ckpt/MiniCPM/MiniCPM-Llama3-V-2_5'
-# 'E:/data/pretrain_ckpt/MiniCPM/MiniCPM-V-2_6'
-demo = MiniCPMV('E:/data/pretrain_ckpt/MiniCPM/MiniCPM-Llama3-V-2_5')
+# E:/data/pretrain_ckpt/MiniCPM/MiniCPM-Llama3-V-2_5
+# E:/data/pretrain_ckpt/MiniCPM/MiniCPM-V-2_6
+demo = MiniCPMV('E:/data/pretrain_ckpt/MiniCPM/MiniCPM-V-2_6')
query1 = 'Describe the content of this image.'
query2 = 'Is the image content related to fund products?'
image1 = Image.open('./test_local/资料概要.png').convert('RGB')
4 changes: 2 additions & 2 deletions setup.py
@@ -7,14 +7,14 @@

setup(
    name='bert4torch',
-    version='v0.5.3',
+    version='v0.5.4',
    description='an elegant bert4torch',
    long_description=long_description,
    long_description_content_type="text/markdown",
    license='MIT Licence',
    url='https://github.com/Tongjilibo/bert4torch',
    author='Tongjilibo',
-    install_requires=['numpy', 'tqdm', 'torch>1.6', 'torch4keras==0.2.6', 'six'],
+    install_requires=['numpy', 'tqdm', 'torch>1.6', 'torch4keras==0.2.7', 'six'],
    packages=find_packages(),
    entry_points={"console_scripts": ["bert4torch-llm-server = bert4torch.pipelines.chat:main"]},

