Skip to content

Commit

Permalink
baichuan is ready
Browse files Browse the repository at this point in the history
  • Loading branch information
Eero committed Feb 18, 2024
1 parent a4d628b commit 07d17b4
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 26 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ Delibird 是一个多合一大模型接口网关。主要针对国内的大模
## 支持模型列表

- [通义千问](https://dashscope.console.aliyun.com/model)
- [文心大模型](https://cloud.baidu.com/product/wenxinworkshop)
- [星火大模型](https://xinghuo.xfyun.cn/sparkapi)
- [文心](https://cloud.baidu.com/product/wenxinworkshop)
- [星火](https://xinghuo.xfyun.cn/sparkapi)
- [Minimax](https://api.minimax.chat/)
- [ChatGLM](https://open.bigmodel.cn/dev/api)
- [百川](https://www.baichuan-ai.com/home)

## 未来计划
- [ ] function calling 支持
Expand Down
19 changes: 19 additions & 0 deletions src/delibird/router/baichuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .base import Base
from .common import common_decode


class Baichuan(Base):
    """Driver for the Baichuan model service.

    Only the request headers are built here; the actual request and the
    streaming of results are delegated to ``Base.send``, with
    ``common_decode`` supplied as the chunk filter.
    """

    def __init__(self):
        # Driver name; NOTE(review): presumably matched against the router
        # configuration -- confirm against the gateway.
        self.name = "baichuan"

    async def send(self, messages, model):
        """Forward ``messages`` for ``model`` and yield what the parent yields.

        Args:
            messages: Request messages, passed through to ``Base.send``.
            model: Model name, passed through to ``Base.send``.

        Yields:
            Every item produced by ``Base.send``, unchanged.
        """
        # ``self.api_key`` is not set in this class; it is expected to be
        # provided from outside before ``send`` is called.
        bearer = "Bearer " + self.api_key
        request_headers = dict()
        request_headers["Content-Type"] = "application/json"
        request_headers["Authorization"] = bearer

        # Delegate to the parent implementation and relay its output.
        stream = super().send(
            messages,
            model,
            headers=request_headers,
            filter_func=common_decode,
        )
        async for chunk in stream:
            yield chunk
6 changes: 6 additions & 0 deletions src/delibird/router/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ async def _http_send(
# 去掉结尾标记,获取数据,跳出
if snippet_data.endswith(end_tag):
snippet_data = snippet_data[: -len(end_tag)]

# 循环检查一遍末尾是否有结束标记,有就去掉
# 为了避免有些服务返回结束标记,然后在最后一条也有 finish_reason
while snippet_data.endswith(end_tag):
snippet_data = snippet_data[: -len(end_tag)]

output += snippet_data
break

Expand Down
25 changes: 2 additions & 23 deletions src/delibird/router/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jwt
import json
from delibird.log import Log
from .common import decode_data
from .common import decode_data, common_decode


def generate_token(apikey: str, exp_seconds: int):
Expand Down Expand Up @@ -37,32 +37,11 @@ async def send(self, messages, model):
messages: 请求参数。格式为 [ {"role": "user", "content": "Python 如何实现异步编程"}]
model: 对应的模型名称。格式为例如 qwen 就是 qwen-max、qwen-min、qwen-speed、qwen-turbo
"""
logger = Log("delibird")

self.model = model

# 拼接 header,增加 Authorization
headers = {"Authorization": "Bearer " + generate_token(self.api_key, 3600)}

# 返回的数据可能会有多个,所以使用 buffer 存储
buffer = ""

# 调用父类的 send 方法
async for data in super().send(
messages, model, headers=headers, filter_func=_decode_data
messages, model, headers=headers, filter_func=common_decode
):
yield data


def _decode_data(data):
    """Extract the message content from one raw response chunk.

    The chunk is first handed to ``decode_data``. If that reports failure,
    its second return value is passed back as-is. Otherwise the content is
    read from ``choices[0].delta.content`` of the decoded object.

    Args:
        data: Raw chunk to decode.

    Returns:
        A ``(success, value)`` tuple: ``(True, content)`` on success,
        ``(False, <decode_data payload>)`` when decoding failed, or
        ``(False, "数据格式错误")`` when an expected key is missing.
    """
    ok, payload = decode_data(data)

    if ok:
        try:
            content = payload["choices"][0]["delta"]["content"]
        except KeyError:
            # The decoded object does not have the expected key layout.
            return (False, "数据格式错误")
        return (True, content)

    return (False, payload)
49 changes: 49 additions & 0 deletions src/delibird/router/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,52 @@ def decode_data(data, start="data:", last="\n\n", end_tag="[DONE]"):
return True, data
except json.JSONDecodeError as e:
return False, {}


def common_decode(data):
    """Decode one streamed chunk and return its message content.

    The expected chunk layout is:

    1. The chunk starts with ``data:`` and is terminated by a blank line.
    2. The payload is either the literal end marker ``[DONE]`` (optional;
       some services never send it) or a JSON object.
    3. In the JSON object the text lives at ``choices[0].delta.content``.
    4. A non-empty ``finish_reason`` (on the first choice, or at the top
       level of the object) marks the last chunk; ``[DONE]`` is then
       appended to the returned text so the caller can detect the end.

    Args:
        data: Raw chunk text to decode.

    Returns:
        A ``(success, message)`` tuple. ``message`` is a ``str``; it is
        ``""`` whenever ``success`` is ``False`` (wrong prefix, invalid
        JSON, or a payload without the expected structure).
    """
    start = "data:"
    end_tag = "[DONE]"

    if not data.startswith(start):
        return False, ""

    # Slice the prefix off instead of using ``lstrip``: ``lstrip`` removes a
    # *set of characters*, not a prefix, and it left the space of
    # "data: [DONE]" in place so the end marker was never recognised.
    # ``strip`` then drops the surrounding whitespace and trailing newlines.
    payload = data[len(start):].strip()

    # The bare end marker is passed through unchanged.
    if payload == end_tag:
        return True, payload

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return False, ""

    # Any structural surprise (not an object, no choices, delta of the wrong
    # type, ...) is reported as a failed decode rather than raised.
    try:
        choice = parsed["choices"][0]
        delta = choice.get("delta") or {}
        content = delta.get("content")
        # Only a non-empty value counts: an explicit ``null`` finish_reason
        # must not be treated as the end of the stream.
        finished = bool(
            choice.get("finish_reason") or parsed.get("finish_reason")
        )
    except (KeyError, IndexError, TypeError, AttributeError):
        return False, ""

    text = content if isinstance(content, str) else None

    if finished:
        # The final chunk may carry no text at all; still signal the end.
        return True, (text or "") + end_tag

    if text is None:
        return False, ""

    return True, text
3 changes: 2 additions & 1 deletion src/delibird/router/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .minimax import Minimax
from .spark import Spark
from .chatglm import Chatglm
from .baichuan import Baichuan

import sys

Expand Down Expand Up @@ -109,7 +110,7 @@ def _driver_config(self, driver_name):
def _check_drivers(self):
"""检查读取的规则是否正确.
routers 对应的 drives 是否存在.
routers 对应的 drivers 是否存在.
"""

for router in self.__routers:
Expand Down
1 change: 1 addition & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"model": "ernie-bot-turbo",
},
{"name": "minimax", "model": "abab5.5-chat"},
{"name": "baichuan", "model": "Baichuan2-Turbo"},
]


Expand Down

0 comments on commit 07d17b4

Please sign in to comment.