Merge the llama examples; improve the internlm function call
Tongjilibo committed Aug 1, 2024
1 parent 19b8251 commit 37f9dec
Showing 10 changed files with 159 additions and 146 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -93,7 +93,7 @@ pip install git+https://github.com/Tongjilibo/bert4torch
### 4.1 Version history
|Date| bert4torch | torch4keras | Release notes |
|------| ---------------- | ----------------- |----------- |
|20240730| 0.5.2 | 0.2.5 | [New] function call support for the chatglm/qwen series, added the internlm2 series; [Minor] simplified the chat demo calls in pipelines, generate's stop-token entries may now be lists, unified the rope_scaling parameter name, added rope variant classes; [Bugfix] fixed a flash_attn2 inference bug and bart's tie_word_embedding bug|
|20240801| 0.5.2 | 0.2.5 | [New] function call support for the chatglm/qwen series, added the internlm2 series; [Minor] simplified the chat demo calls in pipelines, generate's stop-token entries may now be lists, unified the rope_scaling parameter name, added rope variant classes; [Bugfix] fixed a flash_attn2 inference bug and bart's tie_word_embedding bug|
|20240619| 0.5.1 | 0.2.4 | Added Qwen1.5, Qwen2 and glm4; added SWA/convert_lm_logits_dtype; reworked the trainers (DPOTrainer in particular), segment_ids in generation, repetition_penalty now takes the query into account; fixed the RMSNorm dtype-cast bug|
|20240418| 0.5.0 | 0.2.2 | Fixed a chatglm3 bug, fixed the multi-file bug in save_pretrained, added CausalLMLoss, reworked deepspeed argument passing, fixed a Text2Vec bug, polished the openai client, added get_weight_decay_optim_groups|

@@ -171,6 +171,7 @@ model = build_transformer_model(config_path, checkpoint_path)
| llama | [llama](https://github.com/facebookresearch/llama) | meta| | [`llama-7b`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/llama-7b), [`llama-13b`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/llama-13b)|
| | [llama-2](https://github.com/facebookresearch/llama) | meta| [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), [meta-llama/Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b-hf), [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | [`Llama-2-7b-hf`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/Llama-2-7b-hf), [`Llama-2-7b-chat-hf`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/Llama-2-7b-chat-hf), [`Llama-2-13b-hf`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/Llama-2-13b-hf), [`Llama-2-13b-chat-hf`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/Llama-2-13b-chat-hf)|
| | [llama-3](https://github.com/meta-llama/llama3) | meta| [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B), [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | [`Meta-Llama-3-8B`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/Meta-Llama-3-8B), [`Meta-Llama-3-8B-Instruct`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/Meta-Llama-3-8B-Instruct)|
| | [llama-3.1](https://github.com/meta-llama/llama-models) | meta | | to be added|
| | [Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)|HFL| |[`chinese_alpaca_plus_7b`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/chinese_alpaca_plus_7b), [`chinese_llama_plus_7b`](https://huggingface.co/Tongjilibo/bert4torch_config/tree/main/chinese_llama_plus_7b)|
| | [Chinese-LLaMA-Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)|HFL| |to be added|
| | [Chinese-LLaMA-Alpaca-3](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)|HFL| |to be added|
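For orientation, the `model = build_transformer_model(config_path, checkpoint_path)` call shown in the hunk above is all that is needed to load any checkpoint from this table. A minimal sketch, with placeholder local paths (the weights come from the HF links in the table, the config from the matching `Tongjilibo/bert4torch_config` entry):

```python
from bert4torch.models import build_transformer_model
from transformers import AutoTokenizer

# placeholder paths: download the weights and the bert4torch_config.json first
checkpoint_dir = '/pretrain_ckpt/Llama-2-7b-chat-hf'
config_path = f'{checkpoint_dir}/bert4torch_config.json'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = build_transformer_model(config_path, checkpoint_path=checkpoint_dir)
```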
33 changes: 30 additions & 3 deletions bert4torch/generation.py
@@ -86,7 +86,7 @@ def __init__(self,
n:int=1,
top_k:int=None,
top_p:float=None,
temperature:float=1,
temperature:float=1.0,
repetition_penalty:float=1.0,
min_ends:int=1,
**generation_config):
@@ -99,15 +99,42 @@ def __init__(self,
self.pad_token_id = pad_token_id # compatible with both bert4torch and hf pad tokens; if it resolves wrongly, pass pad_id:int explicitly
self.pad_mode = pad_mode
self.device = device

# number of samples to generate
if not isinstance(n, int) or n <= 0:
raise ValueError(f"`n` has to be a strictly positive integer, but is {n}")
self.n = n

# top-k sampling
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k

# top-p (nucleus) sampling
if top_p is not None and (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
self.top_p = top_p

# temperature coefficient
if not isinstance(temperature, (int,float)) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
"scores will be invalid."
)
if isinstance(temperature, (int,float)) and temperature == 0.0:
except_msg += " If you're looking for greedy decoding strategies, set `top_k=1`."
raise ValueError(except_msg)
self.temperature = temperature

# repetition penalty coefficient
if not isinstance(repetition_penalty, (int,float)) or not (repetition_penalty > 0):
raise ValueError(f"`repetition_penalty` has to be a strictly positive float, but is {repetition_penalty}")
self.repetition_penalty = repetition_penalty
self.min_ends = min_ends

self.min_ends = min_ends
self.return_last_token = False
self.return_states = False
# parameter aliases: compatible with transformers-style settings
# parameter aliases: compatible with older bert4torch examples
self.alias = {'start_id': 'bos_token_id',
'end_id': 'eos_token_id',
'topk': 'top_k',
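The hunk above adds strict argument validation to `__init__`. A standalone sketch of the same checks (a hypothetical helper mirroring the code above, not the library API) shows what is now rejected:

```python
def check_generation_args(n=1, top_k=None, top_p=None, temperature=1.0, repetition_penalty=1.0):
    """Standalone mirror of the validation above, for illustration only."""
    if not isinstance(n, int) or n <= 0:
        raise ValueError(f"`n` has to be a strictly positive integer, but is {n}")
    if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
        raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
    if top_p is not None and not (0.0 <= top_p <= 1.0):
        raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
    if not isinstance(temperature, (int, float)) or temperature <= 0:
        raise ValueError(f"`temperature` (={temperature}) has to be strictly positive; for greedy decoding set `top_k=1`")
    if not isinstance(repetition_penalty, (int, float)) or repetition_penalty <= 0:
        raise ValueError(f"`repetition_penalty` has to be a strictly positive float, but is {repetition_penalty}")

check_generation_args(temperature=0.7, top_k=50, top_p=0.9)  # passes
try:
    check_generation_args(temperature=0.0)                   # rejected
except ValueError as e:
    print(e)
```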
119 changes: 75 additions & 44 deletions bert4torch/pipelines/chat.py
@@ -5,7 +5,7 @@
Implements API for LLM in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
Usage: python openai_api.py
Visit http://localhost:8000/docs for documents.
3. Quickly build a web demo
3. Quickly build a web demo (gradio+streamlit)
# TODO: with return_states=True the prompt is hard to reproduce exactly because of build_prompt
self.generation_config['states']['last_token'] is added here because generation may have finished by reaching max_length, not necessarily by hitting an eos
@@ -1198,48 +1198,57 @@ def process_response_history(self, response, history=None):
@add_start_docstrings(CHAT_START_DOCSTRING)
class InternLM2(ChatBase):
'''internlm2 supports function call; the expected functions format is:
```python
[
{
"name": "track",
"description": "追踪指定股票的实时价格",
"parameters":
{
"type": "object",
"properties": {"symbol":
{
"description": "需要追踪的股票代码"
}
},
"required": []
}
}
]
```
Since _additional_special_tokens is ['<|im_start|>', '<|im_end|>', '<|action_start|>', '<|action_end|>', '<|interpreter|>', '<|plugin|>'],
decoding with skip_special_tokens=True would drop '<|action_start|>', '<|action_end|>', '<|interpreter|>' and '<|plugin|>' during function call,
so bert4torch_config.json leaves skip_special_tokens unset and it defaults to False.
'''
def __init__(self, *args, system:str=None, **kwargs):
super().__init__(*args, **kwargs)
self.system = system if system is not None else SYSTEM_ZH
self.plugin_with_name = True

self.api_prefix = (
"This is the subfunction for tool '{tool_name}', you can use this tool. "
'The description of this function is: \n{description}')

self.meta_prompt = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')

INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
'当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
'这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
'复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
'文本处理和分析(比如文本解析和自然语言处理),'
'机器学习和数据科学(用于展示模型训练和数据可视化),'
'以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')

self.plugin_prompt = ('你可以使用如下工具:'
'\n{prompt}\n'
'如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
'同时注意你可以使用的工具,不要随意捏造!')


def build_prompt(self, query:str, history:List[dict], functions:List[dict]=None):
if (len(history) == 0) or (history[0]["role"] != "system"):
history.insert(0, {"role": "system", "content": self.system})
history.insert(0, {"role": "system", "content": self.system if functions is None else self.meta_prompt})

if (functions is not None) and all([h['role'] !='function' for h in history]):
# history does not contain any function entries yet
start = [i for i, v in enumerate(history) if v['role']=='system'][-1] + 1
plugin_descriptions = []
for i, func in enumerate(functions, start=start):
name, description = func['name'], func['description']
if not name.startswith('<|') and not name.endswith('|>'):
name = f'<|{name}|>'
if func.get('config', {}).get('with_name', True):
content = f"""<|im_start|>system name={name}\n{description}<|im_end|>\n"""
else:
content = f"""<|im_start|>system\n{name}\n{description}<|im_end|>\n"""
history.insert(i, {"role": "function", "content": content})
plugin = copy.deepcopy(func)
name = plugin['name'].split('.')[0]
plugin['description'] = self.api_prefix.format(tool_name=name, description=plugin['description'])
plugin_descriptions.append(plugin)

plugin_prompt = self.plugin_prompt.format(prompt=json.dumps(plugin_descriptions, ensure_ascii=False, indent=4))

if self.plugin_with_name:
content = f"""<|im_start|>system name=<|plugin|>\n{plugin_prompt}<|im_end|>\n"""
else:
content = f"""<|im_start|>system\n<|plugin|>\n{plugin_prompt}<|im_end|>\n"""
history.insert(i, {"role": "function", "content": content})

if self.tokenizer.add_bos_token:
prompt = ""
Expand Down Expand Up @@ -1271,19 +1280,12 @@ def process_response_history(self, response, history=None):

start_token = '<|action_start|>'
end_token = '<|action_end|>'
plugin_token = '<|plugin|>'
interpreter_token = '<|interpreter|>'
if plugin_token in response:
response, arguments = response.split(f"{start_token}{plugin_token}")
arguments = arguments.split(end_token)[0]
response = response.split(start_token)[0]
history[-1]['function_call'] = {"name": 'plugin', "arguments": arguments}

if interpreter_token in response:
response, arguments = response.split(f"{start_token}{interpreter_token}")
arguments = arguments.split(end_token)[0].strip()
response = response.split(start_token)[0]
history[-1]['function_call'] = {"name": 'interpreter', "arguments": arguments}
for _token in ['<|plugin|>', '<|interpreter|>']:
if _token in response:
response, arguments = response.split(f"{start_token}{_token}")
arguments = arguments.split(end_token)[0].strip()
response = response.split(start_token)[0]
history[-1]['function_call'] = {"name": _token.strip('<|>'), "arguments": arguments}  # name follows the captured token ('plugin' or 'interpreter')

return response
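To make the rewritten loop concrete, here is a self-contained run over an invented internlm2-style response (the sample string and the printed values are illustrative only):

```python
start_token, end_token = '<|action_start|>', '<|action_end|>'
raw = ('Let me look that up for you.'
       '<|action_start|><|plugin|>\n'
       '{"name": "track", "parameters": {"symbol": "AAPL"}}'
       '<|action_end|>')

response, function_call = raw, None
for _token in ['<|plugin|>', '<|interpreter|>']:
    if _token in response:
        # keep the plain text before the action span, capture the arguments inside it
        response, arguments = response.split(f"{start_token}{_token}")
        arguments = arguments.split(end_token)[0].strip()
        response = response.split(start_token)[0]
        function_call = {"name": _token.strip('<|>'), "arguments": arguments}

print(response)       # Let me look that up for you.
print(function_call)  # {'name': 'plugin', 'arguments': '{"name": "track", ...}'}
```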

@@ -1950,6 +1952,33 @@ def build_prompt(self, query:str, history:List[dict], functions:List[dict]=None)
return total_input


@add_start_docstrings(CHAT_START_DOCSTRING)
class PretrainedTextContinuation(ChatBase):
'''Plain text continuation with a pretrained base model'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def build_prompt(self, query:str, history:List[dict], functions:List[dict]=None) -> str:
if functions is not None:
log_warn('PretrainedTextContinuation does not support function call')

total_input = ''
if self.no_history_states():
for query_or_response in history:
role, content = query_or_response['role'], query_or_response['content']
if self.generation_config.get('include_input', False):
if role == 'assistant':
total_input += content
else:
total_input += content
else:
total_input += [self.generation_config['states']['last_token_id']]
total_input += query

history.append({"role": "user", "content": query})
return total_input


MAPPING = {
'glm': Glm,
'glm2': Glm2,
@@ -2054,8 +2083,11 @@ def __new__(cls, *args, mode:Literal['cli', 'gradio', 'streamlit', 'openai']='cl
if template is None:
raise ValueError('template/model/model_type not found in bert4torch_config.json')
elif template not in MAPPING:
raise ValueError(f'template:{template} not supported')
ChatTemplate = MAPPING[template]
log_info('PretrainedTextContinuation is used; it can only continue your text.')
ChatTemplate = PretrainedTextContinuation
else:
ChatTemplate = MAPPING[template]
log_info(f'Chat pipeline uses template=`{template}` and mode=`{mode}`')

if mode == 'cli':
@add_start_docstrings(CHAT_START_DOCSTRING)
@@ -2071,5 +2103,4 @@ class ChatDemo(ChatTemplate, ChatWebStreamlit): pass
class ChatDemo(ChatTemplate, ChatOpenaiApi): pass
else:
raise ValueError(f'Unsupported mode={mode}')
log_info(f'Chat pipeline uses template=`{template}` and mode=`{mode}`')
return ChatDemo(*args, **kwargs)
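Putting the pieces together, a hedged usage sketch of the internlm2 function-call flow through the factory above; the checkpoint path is a placeholder, and passing `functions` into `run()` is assumed to follow the repo's function-call demos:

```python
from bert4torch.pipelines import Chat

# functions in the format from the InternLM2 docstring above
functions = [{
    "name": "track",
    "description": "Track the real-time price of a given stock",
    "parameters": {
        "type": "object",
        "properties": {"symbol": {"description": "the stock code to track"}},
        "required": []
    }
}]

demo = Chat('/pretrain_ckpt/internlm2-chat-7b',  # placeholder checkpoint dir
            mode='cli')
demo.run(functions=functions)  # assumption: run() threads `functions` into build_prompt
```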
2 changes: 1 addition & 1 deletion docs/History.md
@@ -1,6 +1,6 @@
## Update history

- **20240730**: [New] function call support for the chatglm/qwen series, added the internlm2 series; [Minor] simplified the chat demo calls in pipelines, generate's stop-token entries may now be lists, unified the rope_scaling parameter name, added rope variant classes; [Bugfix] fixed a flash_attn2 inference bug and bart's tie_word_embedding bug
- **20240801**: [New] function call support for the chatglm/qwen series, added the internlm2 series; [Minor] simplified the chat demo calls in pipelines, generate's stop-token entries may now be lists, unified the rope_scaling parameter name, added rope variant classes; [Bugfix] fixed a flash_attn2 inference bug and bart's tie_word_embedding bug
- **20240619**: Added Qwen1.5, Qwen2 and glm4; added SWA/convert_lm_logits_dtype; reworked the trainers (DPOTrainer in particular), segment_ids in generation, repetition_penalty now takes the query into account
- **20240426**: Simplified the LLM demo calls, generation_config is now read from config, added Qwen2 and SWA, fixed the RMSNorm dtype-cast bug
- **20240418**: Fixed a Text2Vec bug, polished the openai client, added get_weight_decay_optim_groups
2 changes: 1 addition & 1 deletion docs/Update.md
@@ -2,7 +2,7 @@

|Date| bert4torch version | torch4keras version | Release notes |
|------| ---------------- | ----------------- |----------- |
|20240730| 0.5.2 | 0.2.5 | [New] function call support for the chatglm/qwen series, added the internlm2 series; [Minor] simplified the chat demo calls in pipelines, generate's stop-token entries may now be lists, unified the rope_scaling parameter name, added rope variant classes; [Bugfix] fixed a flash_attn2 inference bug and bart's tie_word_embedding bug|
|20240801| 0.5.2 | 0.2.5 | [New] function call support for the chatglm/qwen series, added the internlm2 series; [Minor] simplified the chat demo calls in pipelines, generate's stop-token entries may now be lists, unified the rope_scaling parameter name, added rope variant classes; [Bugfix] fixed a flash_attn2 inference bug and bart's tie_word_embedding bug|
|20240619| 0.5.1 | 0.2.4 | Added Qwen1.5, Qwen2 and glm4; added SWA/convert_lm_logits_dtype; reworked the trainers (DPOTrainer in particular), segment_ids in generation, repetition_penalty now takes the query into account; fixed the RMSNorm dtype-cast bug|
|20240418| 0.5.0 | 0.2.2 | Fixed a chatglm3 bug, fixed the multi-file bug in save_pretrained, added CausalLMLoss, reworked deepspeed argument passing, fixed a Text2Vec bug, polished the openai client, added get_weight_decay_optim_groups|
|20240317| 0.4.9.post2 | 0.2.1.post2 |Added the get_weight_decay_optim_groups function, attention now accepts is_causal, fixed a repetition_penalty bug, split baichuan out of llama, fixed a config_path bug, added the num_key_value_heads parameter, feature updates from [torch4keras-v0.2.1.post2](https://github.com/Tongjilibo/torch4keras/releases/tag/v0.2.1.post2)|
3 changes: 1 addition & 2 deletions examples/README.md
@@ -34,8 +34,7 @@
| |[basic_language_model_llama_chinese_llama_alpaca.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/llama/basic_language_model_llama_chinese_llama_alpaca.py): tests the [chinese_llama_alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) model.
| |[basic_language_model_llama_vicuna.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/llama/basic_language_model_llama_vicuna.py): tests the [vicuna](https://hf-mirror.com/lmsys/vicuna-7b-v1.5) model.
| |[basic_language_model_llama_ziya.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/llama/basic_language_model_llama_ziya.py): tests the [ziya](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1) model.
| |[basic_language_model_llama-2.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/llama/basic_language_model_llama-2.py): tests the [llama-2](https://github.com/facebookresearch/llama) model.
| |[basic_language_model_llama.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/llama/basic_language_model_llama.py): tests the [llama](https://github.com/facebookresearch/llama) model.
| |[basic_language_model_llama.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/llama/basic_language_model_llama.py): tests the [llama](https://github.com/facebookresearch/llama) family of models.
| nezha|[basic_language_model_nezha_gen_gpt.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/nezha/basic_language_model_nezha_gen_gpt.py): tests generation with [GPTBase (a.k.a. NEZHA-GEN)](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-Gen-TensorFlow).
| |[basic_language_model_nezha.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/nezha/basic_language_model_nezha.py): tests the mlm behavior of [nezha](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-TensorFlow).
|others|[basic_language_model_moss.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/others/basic_language_model_moss.py): tests the [moss](https://github.com/OpenLMLab/MOSS) model with low-cost int4 and int8 deployment.