class Baichuan(Base):
    """Router driver for the Baichuan chat service."""

    def __init__(self):
        # Driver name used to identify this router.
        self.name = "baichuan"

    async def send(self, messages, model):
        """Stream the reply for ``messages`` from the given ``model``.

        Builds the JSON content-type and bearer-token headers from
        ``self.api_key`` and delegates the request to the parent class,
        decoding every chunk with ``common_decode``.

        Args:
            messages: Request payload passed through to the parent ``send``.
            model: Model name passed through to the parent ``send``.

        Yields:
            Each item produced by the parent ``send`` generator, unchanged.
        """
        bearer_token = "Bearer " + self.api_key
        request_headers = dict(
            [
                ("Content-Type", "application/json"),
                ("Authorization", bearer_token),
            ]
        )

        # Delegate to the parent implementation and relay its stream as-is.
        upstream = super().send(
            messages,
            model,
            headers=request_headers,
            filter_func=common_decode,
        )
        async for chunk in upstream:
            yield chunk
def common_decode(data):
    """Decode one streamed chunk from a model service and return its text.

    A decodable chunk starts with ``data:`` and is usually terminated by a
    blank line. Its payload is either the literal end marker ``[DONE]``
    (optional; some services never send it) or a JSON object whose text
    lives at ``choices[0].delta.content``. A non-empty ``finish_reason``
    marks the last chunk of a reply.

    Args:
        data: Raw chunk text to decode.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is False (with an empty message)
        when the chunk does not start with ``data:``, is not valid JSON, or
        has no usable ``choices`` entry. Otherwise ``ok`` is True and
        ``message`` is the chunk's content; ``[DONE]`` is appended to the
        content of a final chunk, and returned on its own for the end marker.
    """
    start = "data:"
    end_tag = "[DONE]"

    if not data.startswith(start):
        return False, ""

    # Slice the prefix off instead of using lstrip(): lstrip treats its
    # argument as a set of characters, not as a prefix. Then drop surrounding
    # whitespace so that both "data:[DONE]" and "data: [DONE]" are recognised.
    payload = data[len(start):].strip()

    if payload == end_tag:
        return True, end_tag

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return False, ""

    if not isinstance(parsed, dict):
        return False, ""

    choices = parsed.get("choices")
    if not isinstance(choices, list) or not choices:
        return False, ""

    choice = choices[0]
    if not isinstance(choice, dict):
        return False, ""

    # A chunk may legitimately carry no text (for example a role-only first
    # chunk or an empty final chunk); treat missing or null content as "".
    delta = choice.get("delta")
    content = delta.get("content") if isinstance(delta, dict) else None
    if not isinstance(content, str):
        content = ""

    # finish_reason is reported per choice and is null on non-final chunks,
    # so test its value rather than the presence of the key. A top-level
    # finish_reason is honoured as well for services that put it there.
    # Appending the end marker tells the caller the reply is complete.
    if choice.get("finish_reason") or parsed.get("finish_reason"):
        content += end_tag

    return True, content