diff --git a/README.md b/README.md index d15b334c5..f80fed462 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,21 @@ # New API > [!NOTE] -> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。 -> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 +> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 -> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。 +> [!IMPORTANT] +> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 +> 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 -> [!NOTE] -> 最新版Docker镜像 calciumion/new-api:latest -> 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR +> [!TIP] +> 最新版Docker镜像:`calciumion/new-api:latest` +> 默认账号root 密码123456 +> 更新指令: +> ``` +> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR +> ``` + ## 主要变更 此分叉版本的主要变更如下: @@ -18,9 +24,9 @@ 1. 全新的UI界面(部分界面还待更新) 2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md) 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: - + [x] 易支付 + + [x] 易支付 4. 支持用key查询使用额度: - + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 + + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 5. 渠道显示已使用额度,支持指定组织访问 6. 分页支持选择每页显示数量 7. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db) @@ -51,29 +57,14 @@ 您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 -## 渠道重试 -渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 -如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 -### 缓存设置方法 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 - + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` -2. 
`MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
- + 例子:`MEMORY_CACHE_ENABLED=true`
-### 为什么有的时候没有重试
-这些错误码不会重试:400,504,524
-### 我想让400也重试
-在`渠道->编辑`中,将`状态码复写`改为
-```json
-{
-    "400": "500"
-}
-```
-可以实现400错误转为500错误,从而重试
-
 ## 比原版One API多出的配置
 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
-- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`, 可选值为 `true` 和 `false`
-- `FORCE_STREAM_OPTION`:覆盖客户端stream_options参数,请求上游返回流模式usage,目前仅支持 `OpenAI` 渠道类型
+- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
+- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true`
+- `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
+- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`
+- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度
+
 ## 部署
 ### 部署要求
 - 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
@@ -96,8 +87,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
 docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
 # 注意:数据库要开启远程访问,并且只允许服务器IP访问
 ```
-### 默认账号密码
-默认账号root 密码123456
+
+## 渠道重试
+渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
+如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
+### 缓存设置方法
+1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+2. 
`MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`MEMORY_CACHE_ENABLED=true` +### 为什么有的时候没有重试 +这些错误码不会重试:400,504,524 +### 我想让400也重试 +在`渠道->编辑`中,将`状态码复写`改为 +```json +{ + "400": "500" +} +``` +可以实现400错误转为500错误,从而重试 ## Midjourney接口设置文档 [对接文档](Midjourney.md) diff --git a/common/constants.go b/common/constants.go index 85e6eecb6..6cc7added 100644 --- a/common/constants.go +++ b/common/constants.go @@ -235,6 +235,7 @@ const ( ChannelTypeSunoAPI = 36 ChannelTypeDify = 37 ChannelTypeJina = 38 + ChannelCloudflare = 39 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -280,4 +281,5 @@ var ChannelBaseURLs = []string{ "", //36 "", //37 "https://api.jina.ai", //38 + "https://api.cloudflare.com", //39 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 2b894029e..67ae69a01 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -30,6 +30,8 @@ var defaultModelRatio = map[string]float64{ "gpt-4-32k": 30, "gpt-4-32k-0314": 30, "gpt-4-32k-0613": 30, + "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens + "gpt-4o-mini-2024-07-18": 0.075, "gpt-4o": 2.5, // $0.005 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens "gpt-4-turbo": 5, // $0.01 / 1K tokens @@ -104,12 +106,13 @@ var defaultModelRatio = map[string]float64{ "gemini-1.0-pro-latest": 1, "gemini-1.0-pro-vision-latest": 1, "gemini-ultra": 1, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "glm-4": 7.143, // ¥0.1 / 1k tokens - "glm-4v": 7.143, // ¥0.1 / 1k tokens + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "glm-4": 7.143, // ¥0.1 / 1k tokens + "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens + "glm-4-alltools": 
0.1 * RMB, // ¥0.1 / 1k tokens "glm-3-turbo": 0.3572, "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens @@ -157,6 +160,8 @@ var defaultModelRatio = map[string]float64{ } var defaultModelPrice = map[string]float64{ + "suno_music": 0.1, + "suno_lyrics": 0.01, "dall-e-2": 0.02, "dall-e-3": 0.04, "gpt-4-gizmo-*": 0.1, @@ -313,6 +318,10 @@ func GetCompletionRatio(name string) float64 { return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" { + if strings.HasPrefix(name, "gpt-4o-mini") { + return 4 + } + if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") { return 3 } diff --git a/common/str.go b/common/str.go new file mode 100644 index 000000000..d61adb171 --- /dev/null +++ b/common/str.go @@ -0,0 +1,73 @@ +package common + +import ( + "encoding/json" + "math/rand" + "strconv" + "unsafe" +) + +func GetStringIfEmpty(str string, defaultValue string) string { + if str == "" { + return defaultValue + } + return str +} + +func GetRandomString(length int) string { + //rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func MapToJsonStr(m map[string]interface{}) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} + +func MapToJsonStrFloat(m map[string]float64) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} + +func StrToMap(str string) map[string]interface{} { + m := make(map[string]interface{}) + err := json.Unmarshal([]byte(str), &m) + if err != nil { + return nil + } + return m +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} + +func StringsContains(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + 
} + } + return false +} + +// StringToByteSlice []byte only read, panic on append +func StringToByteSlice(s string) []byte { + tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) + tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} + return *(*[]byte)(unsafe.Pointer(&tmp2)) +} diff --git a/common/utils.go b/common/utils.go index 7059a78a4..9c913e4db 100644 --- a/common/utils.go +++ b/common/utils.go @@ -2,7 +2,6 @@ package common import ( "context" - "encoding/json" "errors" "fmt" "github.com/google/uuid" @@ -18,7 +17,6 @@ import ( "strconv" "strings" "time" - "unsafe" ) func OpenBrowser(url string) { @@ -164,15 +162,6 @@ func GenerateKey() string { return string(key) } -func GetRandomString(length int) string { - //rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - func GetRandomInt(max int) int { //rand.Seed(time.Now().UnixNano()) return rand.Intn(max) @@ -199,60 +188,11 @@ func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } -func String2Int(str string) int { - num, err := strconv.Atoi(str) - if err != nil { - return 0 - } - return num -} - -func StringsContains(strs []string, str string) bool { - for _, s := range strs { - if s == str { - return true - } - } - return false -} - -// StringToByteSlice []byte only read, panic on append -func StringToByteSlice(s string) []byte { - tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) - tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} - return *(*[]byte)(unsafe.Pointer(&tmp2)) -} - func RandomSleep() { // Sleep for 0-3000 ms time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) } -func MapToJsonStr(m map[string]interface{}) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - -func MapToJsonStrFloat(m map[string]float64) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return 
string(bytes) -} - -func StrToMap(str string) map[string]interface{} { - m := make(map[string]interface{}) - err := json.Unmarshal([]byte(str), &m) - if err != nil { - return nil - } - return m -} - func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) { if "" == proxyUrl { return &http.Client{}, nil diff --git a/constant/env.go b/constant/env.go index 96483fe19..76146cacd 100644 --- a/constant/env.go +++ b/constant/env.go @@ -9,3 +9,9 @@ var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true) // ForceStreamOption 覆盖请求参数,强制返回usage信息 var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) + +var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) + +var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) + +var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) diff --git a/controller/channel-test.go b/controller/channel-test.go index 000d7f2a6..fe279785a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "io" "math" "net/http" @@ -12,6 +13,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/middleware" "one-api/model" "one-api/relay" relaycommon "one-api/relay/common" @@ -24,7 +26,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) { +func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { tik := time.Now() if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil @@ -40,29 +42,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr Body: nil, Header: make(http.Header), } - c.Request.Header.Set("Authorization", "Bearer "+channel.Key) - c.Request.Header.Set("Content-Type", 
"application/json") - c.Set("channel", channel.Type) - c.Set("base_url", channel.GetBaseURL()) - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) - } - meta := relaycommon.GenRelayInfo(c) - apiType, _ := constant.ChannelType2APIType(channel.Type) - adaptor := relay.GetAdaptor(apiType) - if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil - } if testModel == "" { if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel @@ -79,8 +59,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error - return err, &openaiErr + return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[testModel] != "" { testModel = modelMap[testModel] @@ -88,14 +67,28 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr } } + c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + c.Request.Header.Set("Content-Type", "application/json") + c.Set("channel", channel.Type) + c.Set("base_url", channel.GetBaseURL()) + + middleware.SetupContextForSelectedChannel(c, channel, testModel) + + meta := relaycommon.GenRelayInfo(c) + apiType, _ := constant.ChannelType2APIType(channel.Type) + adaptor := relay.GetAdaptor(apiType) + if adaptor == nil { + return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + } + request := 
buildTestRequest() request.Model = testModel meta.UpstreamModelName = testModel common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) - adaptor.Init(meta, *request) + adaptor.Init(meta) - convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + convertedRequest, err := adaptor.ConvertRequest(c, meta, request) if err != nil { return err, nil } @@ -110,12 +103,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr return err, nil } if resp != nil && resp.StatusCode != http.StatusOK { - err := relaycommon.RelayErrorHandler(resp) - return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error + err := service.RelayErrorHandler(resp) + return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error + return fmt.Errorf("%s", respErr.Error.Message), respErr } if usage == nil { return errors.New("usage is nil"), nil @@ -225,11 +218,11 @@ func testAllChannels(notify bool) error { if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } - go func() { + gopool.Go(func() { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel, "") + err, openaiWithStatusErr := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() @@ -238,27 +231,29 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) ban = true } - if openaiErr != nil { - err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) - ban = true + + // request error disables the channel + if openaiWithStatusErr != nil { + oaiErr 
:= openaiWithStatusErr.Error + err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) + ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) } + // parse *int to bool if channel.AutoBan != nil && *channel.AutoBan == 0 { ban = false } - if openaiErr != nil { - openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{ - StatusCode: -1, - Error: *openaiErr, - LocalError: false, - } - if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban { - service.DisableChannel(channel.Id, channel.Name, err.Error()) - } - if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { - service.EnableChannel(channel.Id, channel.Name) - } + + // disable channel + if ban && isChannelEnabled { + service.DisableChannel(channel.Id, channel.Name, err.Error()) + } + + // enable channel + if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) { + service.EnableChannel(channel.Id, channel.Name) } + channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } @@ -271,7 +266,7 @@ func testAllChannels(notify bool) error { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } - }() + }) return nil } diff --git a/controller/midjourney.go b/controller/midjourney.go index 2d538f1af..464527a86 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -146,28 +146,26 @@ func UpdateMidjourneyTaskBulk() { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) } - + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" - err = model.CacheUpdateUserQuota(task.UserId) - if err != nil { - common.LogError(ctx, "error update user quota cache: 
"+err.Error()) - } else { - quota := task.Quota - if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota) - if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) - } + if task.Quota != 0 { + shouldReturnQuota = true } } err = task.Update() if err != nil { common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + } else { + if shouldReturnQuota { + err = model.IncreaseUserQuota(task.UserId, task.Quota) + if err != nil { + common.LogError(ctx, "fail to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } } } } diff --git a/controller/model.go b/controller/model.go index 7e3a3210d..6b4a878b9 100644 --- a/controller/model.go +++ b/controller/model.go @@ -131,7 +131,7 @@ func init() { } meta := &relaycommon.RelayInfo{ChannelType: i} adaptor := relay.GetAdaptor(apiType) - adaptor.Init(meta, dto.GeneralOpenAIRequest{}) + adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } } diff --git a/controller/relay.go b/controller/relay.go index a04c85a2c..bc951f77e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -22,13 +22,13 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode switch relayMode { case relayconstant.RelayModeImagesGenerations: - err = relay.RelayImageHelper(c, relayMode) + err = relay.ImageHelper(c, relayMode) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c, relayMode) + err = relay.AudioHelper(c) case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) default: diff --git 
a/dto/audio.go b/dto/audio.go index c67d67857..c36b3da54 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,13 +1,34 @@ package dto -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` +type AudioRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` + Speed float64 `json:"speed,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` } type AudioResponse struct { Text string `json:"text"` } + +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} diff --git a/dto/dalle.go b/dto/dalle.go index d366051ce..d0bba655e 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,9 +12,11 @@ type ImageRequest struct { } type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - } + Data []ImageData `json:"data"` + Created int64 `json:"created"` +} +type ImageData struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` } diff --git a/dto/text_request.go b/dto/text_request.go index c50975582..f2edf6a3b 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -29,12 +29,13 @@ type 
GeneralOpenAIRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Tools []ToolCall `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"` LogitBias any `json:"logit_bias,omitempty"` LogProbs any `json:"logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type OpenAITools struct { @@ -52,8 +53,8 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r GeneralOpenAIRequest) GetMaxTokens() int64 { - return int64(r.MaxTokens) +func (r GeneralOpenAIRequest) GetMaxTokens() int { + return int(r.MaxTokens) } func (r GeneralOpenAIRequest) ParseInput() []string { @@ -107,6 +108,11 @@ func (m Message) StringContent() string { return string(m.Content) } +func (m *Message) SetStringContent(content string) { + jsonContent, _ := json.Marshal(content) + m.Content = jsonContent +} + func (m Message) IsStringContent() bool { var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { @@ -146,7 +152,7 @@ func (m Message) ParseContent() []MediaMessage { if ok { subObj["detail"] = detail.(string) } else { - subObj["detail"] = "auto" + subObj["detail"] = "high" } contentList = append(contentList, MediaMessage{ Type: ContentTypeImageURL, @@ -155,7 +161,16 @@ func (m Message) ParseContent() []MediaMessage { Detail: subObj["detail"].(string), }, }) + } else if url, ok := contentMap["image_url"].(string); ok { + contentList = append(contentList, MediaMessage{ + Type: ContentTypeImageURL, + ImageUrl: MessageImageUrl{ + Url: url, + Detail: "high", + }, + }) } + } } return contentList diff --git a/dto/text_response.go b/dto/text_response.go index 3310d0214..9b12683c2 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -66,10 +66,6 @@ type 
ChatCompletionsStreamResponseChoiceDelta struct { ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool { - return c.Content == nil && len(c.ToolCalls) == 0 -} - func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { c.Content = &s } @@ -90,9 +86,11 @@ type ToolCall struct { } type FunctionCall struct { - Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Name string `json:"name,omitempty"` // call function with arguments in JSON format - Arguments string `json:"arguments,omitempty"` + Parameters any `json:"parameters,omitempty"` // request + Arguments string `json:"arguments,omitempty"` } type ChatCompletionsStreamResponse struct { @@ -105,6 +103,17 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { + if c.SystemFingerprint == nil { + return "" + } + return *c.SystemFingerprint +} + +func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) { + c.SystemFingerprint = &s +} + type ChatCompletionsStreamResponseSimple struct { Choices []ChatCompletionsStreamResponseChoice `json:"choices"` Usage *Usage `json:"usage"` diff --git a/go.mod b/go.mod index 12c0de2e3..a88e23569 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect + github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index e8cb7ba3c..e4fad3c3a 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 
h1:JgHnonzbnA3pbqj76w github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0= +github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -205,6 +207,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -214,6 +217,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys 
v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main.go b/main.go index e929e0cb2..959b795f8 100644 --- a/main.go +++ b/main.go @@ -3,12 +3,14 @@ package main import ( "embed" "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "log" "net/http" "one-api/common" + "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/model" @@ -89,11 +91,11 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } - if common.IsMasterNode { - common.SafeGoroutine(func() { + if common.IsMasterNode && constant.UpdateTask { + gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() }) - common.SafeGoroutine(func() { + gopool.Go(func() { controller.UpdateTaskBulk() }) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 61361e6a6..f150b41fc 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "fmt" "net/http" "one-api/common" @@ -25,6 +26,10 @@ func Distribute() func(c *gin.Context) { var channel *model.Channel channelId, ok := c.Get("specific_channel_id") modelRequest, shouldSelectChannel, err := getModelRequest(c) + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) + return + } userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) if ok { @@ -141,7 +146,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } if err != nil { abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) - return nil, false, err + return nil, false, errors.New("无效的请求, " + 
err.Error()) } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if modelRequest.Model == "" { @@ -154,18 +159,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e" - } + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - if modelRequest.Model == "" { - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - modelRequest.Model = "tts-1" - } else { - modelRequest.Model = "whisper-1" - } + relayMode := relayconstant.RelayModeAudioSpeech + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") + relayMode = relayconstant.RelayModeAudioTranslation + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") + relayMode = relayconstant.RelayModeAudioTranscription } + c.Set("relay_mode", relayMode) } return &modelRequest, shouldSelectChannel, nil } @@ -198,11 +207,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) case common.ChannelTypeGemini: c.Set("api_version", channel.Other) case common.ChannelTypeAli: c.Set("plugin", channel.Other) + case common.ChannelCloudflare: + 
c.Set("api_version", channel.Other) } } diff --git a/model/log.go b/model/log.go index 75da845ad..1d1dc1575 100644 --- a/model/log.go +++ b/model/log.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" "strings" @@ -87,7 +88,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke common.LogError(ctx, "failed to record log: "+err.Error()) } if common.DataExportEnabled { - common.SafeGoroutine(func() { + gopool.Go(func() { LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) }) } diff --git a/model/utils.go b/model/utils.go index 44bfbb9e2..3905e9511 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,6 +2,7 @@ package model import ( "errors" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" "sync" @@ -28,12 +29,12 @@ func init() { } func InitBatchUpdater() { - go func() { + gopool.Go(func() { for { time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) batchUpdate() } - }() + }) } func addNewRecord(type_ int, id int, value int) { diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index e222a7007..870b2b0fb 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -10,12 +10,13 @@ import ( type Adaptor interface { // Init IsStream bool - Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) - InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) + Init(info *relaycommon.RelayInfo) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error - ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) + ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, 
error) + ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) + ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index fbaf54655..ff9d5330a 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" ) @@ -15,17 +16,18 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { -} - -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { - +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl) - if info.RelayMode == constant.RelayModeEmbeddings { + var fullRequestURL string + switch info.RelayMode { + case constant.RelayModeEmbeddings: fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + default: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) } return fullRequestURL, nil } @@ -42,22 +44,32 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode 
int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case constant.RelayModeEmbeddings: baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) return baiduEmbeddingRequest, nil default: - baiduRequest := requestOpenAI2Ali(*request) - return baiduRequest, nil + aliReq := requestOpenAI2Ali(*request) + return aliReq, nil } } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + aliRequest := oaiImage2Ali(request) + return aliRequest, nil +} + func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { @@ -65,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = aliStreamHandler(c, resp) - } else { - switch info.RelayMode { - case constant.RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - default: - err, usage = aliHandler(c, resp) + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + if info.IsStream { + err, usage = 
openai.OaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } } return diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index fd1f07a10..f51286ad8 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -60,13 +60,40 @@ type AliUsage struct { TotalTokens int `json:"total_tokens"` } +type TaskResult struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` + Results []TaskResult `json:"results,omitempty"` } -type AliChatResponse struct { +type AliResponse struct { Output AliOutput `json:"output"` Usage AliUsage `json:"usage"` AliError } + +type AliImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go new file mode 100644 index 000000000..3f2705cd0 --- /dev/null +++ b/relay/channel/ali/image.go @@ -0,0 +1,177 @@ +package ali + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" + "time" +) + +func 
oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { + var imageRequest AliImageRequest + imageRequest.Input.Prompt = request.Prompt + imageRequest.Model = request.Model + imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) + imageRequest.Parameters.N = request.N + imageRequest.ResponseFormat = request.ResponseFormat + + return &imageRequest +} + +func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) { + url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) + + var aliResponse AliResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + common.SysError("updateTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("updateTask io.ReadAll err: " + err.Error()) + return &aliResponse, err, nil + } + + var response AliResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + common.SysError("updateTask json.Unmarshal err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) { + waitSeconds := 3 + step := 0 + maxStep := 20 + + var taskResponse AliResponse + var responseBody []byte + + for { + step++ + rsp, err, body := updateTask(info, taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} 
+ +func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse { + imageResponse := dto.ImageResponse{ + Created: info.StartTime.Unix(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + _, b64, err := common.GetImageFromUrl(data.Url) + if err != nil { + common.LogError(c, "get_image_data_failed: "+err.Error()) + continue + } + b64Json = b64 + } else { + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, dto.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse AliResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + 
+ if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} diff --git a/relay/channel/ali/relay-ali.go b/relay/channel/ali/text.go similarity index 82% rename from relay/channel/ali/relay-ali.go rename to relay/channel/ali/text.go index 4280b1c56..aec857fad 100644 --- a/relay/channel/ali/relay-ali.go +++ b/relay/channel/ali/text.go @@ -16,34 +16,13 @@ import ( const EnableSearchModelSuffix = "-internet" -func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest { - messages := make([]AliMessage, 0, len(request.Messages)) - //prompt := "" - for i := 0; i < len(request.Messages); i++ { - message := request.Messages[i] - messages = append(messages, AliMessage{ - Content: message.StringContent(), - Role: strings.ToLower(message.Role), - }) - } - enableSearch := false - aliModel := request.Model - if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { - enableSearch = true - aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) - } - return &AliChatRequest{ - Model: request.Model, - Input: AliInput{ - //Prompt: prompt, - Messages: messages, - }, - Parameters: AliParameters{ - IncrementalOutput: request.Stream, - Seed: uint64(request.Seed), - EnableSearch: enableSearch, - }, +func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { + if request.TopP >= 1 { + request.TopP = 0.999 
+ } else if request.TopP <= 0 { + request.TopP = 0.001 } + return &request } func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { @@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe return &openAIEmbeddingResponse } -func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { +func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse { content, _ := json.Marshal(response.Output.Text) choice := dto.OpenAITextResponseChoice{ Index: 0, @@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { return &fullTextResponse } -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse { +func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(aliResponse.Output.Text) if aliResponse.Output.FinishReason != "null" { @@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage dto.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) dataChan := make(chan string) stopChan := make(chan bool) go func() { @@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var aliResponse AliChatResponse + var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { 
common.SysError("error unmarshalling stream response: " + err.Error()) @@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith } func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var aliResponse AliChatResponse + var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ab1131fe1..423a91d04 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,14 +7,19 @@ import ( "io" "net/http" "one-api/relay/common" + "one-api/relay/constant" "one-api/service" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if info.IsStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") + if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + // multipart/form-data + } else { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if info.IsStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } } } @@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return resp, nil } +func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.GetRequestURL(info) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, 
fmt.Errorf("new request failed: %w", err) + } + // set form data + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + + err = a.SetupRequestHeader(c, req, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { resp, err := service.GetHttpClient().Do(req) if err != nil { diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 6452392a4..44a870d8e 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -20,12 +20,17 @@ type Adaptor struct { RequestMode int } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage } else { @@ -41,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go 
index 17f5384ea..cc0be569e 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -16,12 +16,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } @@ -99,11 +104,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case constant.RelayModeEmbeddings: baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) return baiduEmbeddingRequest, nil diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 462331852..054469591 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -21,12 +21,17 @@ type Adaptor struct { RequestMode int } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info 
*relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage } else { @@ -53,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 47f0c3bad..e2a898ec4 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -5,11 +5,18 @@ type ClaudeMetadata struct { } type ClaudeMediaMessage struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson string `json:"partial_json,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` } type ClaudeMessageSource struct { @@ -23,6 +30,18 @@ type ClaudeMessage struct { Content any `json:"content"` } +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema 
InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + type ClaudeRequest struct { Model string `json:"model"` Prompt string `json:"prompt,omitempty"` @@ -35,7 +54,9 @@ type ClaudeRequest struct { TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } type ClaudeError struct { @@ -44,24 +65,20 @@ type ClaudeError struct { } type ClaudeResponse struct { - Id string `json:"id"` - Type string `json:"type"` - Content []ClaudeMediaMessage `json:"content"` - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` - Usage ClaudeUsage `json:"usage"` - Index int `json:"index"` // stream only - Delta *ClaudeMediaMessage `json:"delta"` // stream only - Message *ClaudeResponse `json:"message"` // stream only: message_start + Id string `json:"id"` + Type string `json:"type"` + Content []ClaudeMediaMessage `json:"content"` + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` + Usage ClaudeUsage `json:"usage"` + Index int `json:"index"` // stream only + ContentBlock *ClaudeMediaMessage `json:"content_block"` + Delta *ClaudeMediaMessage `json:"delta"` // stream only + Message *ClaudeResponse `json:"message"` // stream only: message_start } -//type ClaudeResponseChoice struct { -// Index int `json:"index"` -// Type string `json:"type"` -//} - type ClaudeUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 
81f41a7f7..1e32b7510 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -8,12 +8,10 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" - "time" ) func stopReasonClaude2OpenAI(reason string) string { @@ -30,6 +28,7 @@ func stopReasonClaude2OpenAI(reason string) string { } func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { + claudeRequest := ClaudeRequest{ Model: textRequest.Model, Prompt: "", @@ -60,6 +59,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR } func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { + claudeTools := make([]Tool, 0, len(textRequest.Tools)) + + for _, tool := range textRequest.Tools { + if params, ok := tool.Function.Parameters.(map[string]any); ok { + claudeTools = append(claudeTools, Tool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: InputSchema{ + Type: params["type"].(string), + Properties: params["properties"], + Required: params["required"], + }, + }) + } + } + claudeRequest := ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -68,10 +83,24 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR TopP: textRequest.TopP, TopK: textRequest.TopK, Stream: textRequest.Stream, + Tools: claudeTools, } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 } + if textRequest.Stop != nil { + // stop maybe string/array string, convert to array string + switch textRequest.Stop.(type) { + case string: + claudeRequest.StopSequences = []string{textRequest.Stop.(string)} + case []interface{}: + stopSequences := make([]string, 0) + for _, stop := range textRequest.Stop.([]interface{}) { + stopSequences = append(stopSequences, stop.(string)) + } + claudeRequest.StopSequences = stopSequences + } + } 
formatMessages := make([]dto.Message, 0) var lastMessage *dto.Message for i, message := range textRequest.Messages { @@ -171,6 +200,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) + tools := make([]dto.ToolCall, 0) var choice dto.ChatCompletionsStreamResponseChoice if reqMode == RequestModeCompletion { choice.Delta.SetContentString(claudeResponse.Completion) @@ -186,10 +216,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* choice.Delta.SetContentString("") choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { - return nil, nil + if claudeResponse.ContentBlock != nil { + //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, dto.ToolCall{ + ID: claudeResponse.ContentBlock.Id, + Type: "function", + Function: dto.FunctionCall{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } + } else { + return nil, nil + } } else if claudeResponse.Type == "content_block_delta" { - choice.Index = claudeResponse.Index - choice.Delta.SetContentString(claudeResponse.Delta.Text) + if claudeResponse.Delta != nil { + choice.Index = claudeResponse.Index + choice.Delta.SetContentString(claudeResponse.Delta.Text) + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, dto.ToolCall{ + Function: dto.FunctionCall{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } + } } else if claudeResponse.Type == "message_delta" { finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) if finishReason != "null" { @@ -205,6 +258,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeUsage == nil { claudeUsage = &ClaudeUsage{} } + if len(tools) > 0 { + 
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... + choice.Delta.ToolCalls = tools + } response.Choices = append(response.Choices, choice) return &response, claudeUsage @@ -217,6 +274,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope Object: "chat.completion", Created: common.GetTimestamp(), } + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } + tools := make([]dto.ToolCall, 0) if reqMode == RequestModeCompletion { content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) choice := dto.OpenAITextResponseChoice{ @@ -231,20 +293,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope choices = append(choices, choice) } else { fullTextResponse.Id = claudeResponse.Id - for i, message := range claudeResponse.Content { - content, _ := json.Marshal(message.Text) - choice := dto.OpenAITextResponseChoice{ - Index: i, - Message: dto.Message{ - Role: "assistant", - Content: content, - }, - FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + for _, message := range claudeResponse.Content { + if message.Type == "tool_use" { + args, _ := json.Marshal(message.Input) + tools = append(tools, dto.ToolCall{ + ID: message.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: dto.FunctionCall{ + Name: message.Name, + Arguments: string(args), + }, + }) } - choices = append(choices, choice) } } - + choice := dto.OpenAITextResponseChoice{ + Index: 0, + Message: dto.Message{ + Role: "assistant", + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + choice.SetStringContent(responseText) + if len(tools) > 0 { + choice.Message.ToolCalls = tools + } + choices = append(choices, choice) fullTextResponse.Choices = choices return &fullTextResponse } @@ -256,89 +330,59 @@ func claudeStreamHandler(c 
*gin.Context, resp *http.Response, info *relaycommon. responseText := "" createdTime := common.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil + scanner.Split(bufio.ScanLines) + service.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + info.SetFirstResponseTime() + if len(data) < 6 || !strings.HasPrefix(data, "data:") { + continue } - if atEOF { - return len(data), data, nil + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSpace(data) + var claudeResponse ClaudeResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue } - return 0, nil, nil - }) - dataChan := make(chan string, 5) - stopChan := make(chan bool, 2) - go func() { - for scanner.Scan() { - data := scanner.Text() - if !strings.HasPrefix(data, "data: ") { - continue - } - data = strings.TrimPrefix(data, "data: ") - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } + + response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + if response == nil { + continue } - stopChan <- true - }() - isFirst := true - service.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var claudeResponse ClaudeResponse - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true 
- } + if requestMode == RequestModeCompletion { + responseText += claudeResponse.Completion + responseId = response.Id + } else { + if claudeResponse.Type == "message_start" { + // message_start, capture usage + responseId = claudeResponse.Message.Id + info.UpstreamModelName = claudeResponse.Message.Model + usage.PromptTokens = claudeUsage.InputTokens + } else if claudeResponse.Type == "content_block_delta" { + responseText += claudeResponse.Delta.Text + } else if claudeResponse.Type == "message_delta" { + usage.CompletionTokens = claudeUsage.OutputTokens + usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { - response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) - if response == nil { - return true - } - if requestMode == RequestModeCompletion { - responseText += claudeResponse.Completion - responseId = response.Id } else { - if claudeResponse.Type == "message_start" { - // message_start, capture usage - responseId = claudeResponse.Message.Id - info.UpstreamModelName = claudeResponse.Message.Model - usage.PromptTokens = claudeUsage.InputTokens - } else if claudeResponse.Type == "content_block_delta" { - responseText += claudeResponse.Delta.Text - } else if claudeResponse.Type == "message_delta" { - usage.CompletionTokens = claudeUsage.OutputTokens - usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens - } else { - return true - } + continue } - //response.Id = responseId - response.Id = responseId - response.Created = createdTime - response.Model = info.UpstreamModelName + } + //response.Id = responseId + response.Id = responseId + response.Created = createdTime + response.Model = info.UpstreamModelName - err = service.ObjectData(c, response) - if err != nil { - common.SysError(err.Error()) - } - return true - case <-stopChan: - return false + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, "send_stream_response_failed: 
"+err.Error()) } - }) + } + if requestMode == RequestModeCompletion { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { @@ -357,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } service.Done(c) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() return nil, usage } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go new file mode 100644 index 000000000..a518da8f0 --- /dev/null +++ b/relay/channel/cloudflare/adaptor.go @@ -0,0 +1,105 @@ +package cloudflare + +import ( + "bytes" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + default: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + 
switch info.RelayMode { + case constant.RelayModeCompletions: + return convertCf2CompletionsRequest(*request), nil + default: + return request, nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + // grab the uploaded file field + file, _, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + // in-memory buffer to hold the uploaded file content + requestBody := &bytes.Buffer{} + + // copy the uploaded file into the buffer + if _, err := io.Copy(requestBody, file); err != nil { + return nil, err + } + return requestBody, nil +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fallthrough + case constant.RelayModeChatCompletions: + if info.IsStream { + err, usage = cfStreamHandler(c, resp, info) + } else { + err, usage = cfHandler(c, resp, info) + } + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = cfSTTHandler(c, resp, info) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/cloudflare/constant.go b/relay/channel/cloudflare/constant.go new file mode 100644 index 000000000..a874685af --- /dev/null +++
b/relay/channel/cloudflare/constant.go @@ -0,0 +1,38 @@ +package cloudflare + +var ModelList = []string{ + "@cf/meta/llama-2-7b-chat-fp16", + "@cf/meta/llama-2-7b-chat-int8", + "@cf/mistral/mistral-7b-instruct-v0.1", + "@hf/thebloke/deepseek-coder-6.7b-base-awq", + "@hf/thebloke/deepseek-coder-6.7b-instruct-awq", + "@cf/deepseek-ai/deepseek-math-7b-base", + "@cf/deepseek-ai/deepseek-math-7b-instruct", + "@cf/thebloke/discolm-german-7b-v1-awq", + "@cf/tiiuae/falcon-7b-instruct", + "@cf/google/gemma-2b-it-lora", + "@hf/google/gemma-7b-it", + "@cf/google/gemma-7b-it-lora", + "@hf/nousresearch/hermes-2-pro-mistral-7b", + "@hf/thebloke/llama-2-13b-chat-awq", + "@cf/meta-llama/llama-2-7b-chat-hf-lora", + "@cf/meta/llama-3-8b-instruct", + "@hf/thebloke/llamaguard-7b-awq", + "@hf/thebloke/mistral-7b-instruct-v0.1-awq", + "@hf/mistralai/mistral-7b-instruct-v0.2", + "@cf/mistral/mistral-7b-instruct-v0.2-lora", + "@hf/thebloke/neural-chat-7b-v3-1-awq", + "@cf/openchat/openchat-3.5-0106", + "@hf/thebloke/openhermes-2.5-mistral-7b-awq", + "@cf/microsoft/phi-2", + "@cf/qwen/qwen1.5-0.5b-chat", + "@cf/qwen/qwen1.5-1.8b-chat", + "@cf/qwen/qwen1.5-14b-chat-awq", + "@cf/qwen/qwen1.5-7b-chat-awq", + "@cf/defog/sqlcoder-7b-2", + "@hf/nexusflow/starling-lm-7b-beta", + "@cf/tinyllama/tinyllama-1.1b-chat-v1.0", + "@hf/thebloke/zephyr-7b-beta-awq", +} + +var ChannelName = "cloudflare" diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go new file mode 100644 index 000000000..2f6531c03 --- /dev/null +++ b/relay/channel/cloudflare/dto.go @@ -0,0 +1,21 @@ +package cloudflare + +import "one-api/dto" + +type CfRequest struct { + Messages []dto.Message `json:"messages,omitempty"` + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +type CfAudioResponse struct { + 
Result CfSTTResult `json:"result"` +} + +type CfSTTResult struct { + Text string `json:"text"` +} diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go new file mode 100644 index 000000000..69d6b8534 --- /dev/null +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -0,0 +1,156 @@ +package cloudflare + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" + "time" +) + +func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest { + p, _ := textRequest.Prompt.(string) + return &CfRequest{ + Prompt: p, + MaxTokens: textRequest.GetMaxTokens(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + id := service.GetResponseID(c) + var responseText string + isFirst := true + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + if data == "[DONE]" { + break + } + + var response dto.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) + if err != nil { + common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + continue + } + for _, choice := range response.Choices { + choice.Delta.Role = "assistant" + responseText += choice.Delta.GetContentString() + } + response.Id = id + response.Model = info.UpstreamModelName + err = service.ObjectData(c, response) + if isFirst { + isFirst = false + info.FirstResponseTime = time.Now() + } + if err != nil { + common.LogError(c, "error_rendering_stream_response: "+err.Error()) + } + } + + if 
err := scanner.Err(); err != nil { + common.LogError(c, "error_scanning_stream_response: "+err.Error()) + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + } + } + service.Done(c) + + err := resp.Body.Close() + if err != nil { + common.LogError(c, "close_response_body_failed: "+err.Error()) + } + + return nil, usage +} + +func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var response dto.TextResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = info.UpstreamModelName + var responseText string + for _, choice := range response.Choices { + responseText += choice.Message.StringContent() + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + response.Usage = *usage + response.Id = service.GetResponseID(c) + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + return nil, usage +} + 
+func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var cfResp CfAudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &cfResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + audioResp := &dto.AudioResponse{ + Text: cfResp.Result.Text, + } + + jsonResponse, err := json.Marshal(audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return nil, usage +} diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index b5f352126..3945774c5 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -1,6 +1,7 @@ package cohere import ( + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -14,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request 
dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -34,7 +42,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return requestOpenAI2Cohere(*request), nil } diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index fc6c44500..b2c27390f 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int64 `json:"max_tokens"` + MaxTokens int `json:"max_tokens"` } type ChatHistory struct { diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index a54b95b5e..b582da2c8 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -14,12 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a 
*Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -32,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 975516309..e132d2f27 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -14,10 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } // 定义一个映射,存储模型名称和对应的版本 @@ -40,7 +47,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { action := "generateContent" if info.IsStream { - action = "streamGenerateContent" + action = "streamGenerateContent?alt=sse" } return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } @@ -51,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func 
(a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 99ab6540a..771a616a9 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -12,9 +12,15 @@ type GeminiInlineData struct { Data string `json:"data"` } +type FunctionCall struct { + FunctionName string `json:"name"` + Arguments any `json:"args"` +} + type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` } type GeminiChatContent struct { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 5c3d60c0a..98a7236c7 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -4,18 +4,14 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/gin-gonic/gin" "io" - "log" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" - "time" - - "github.com/gin-gonic/gin" ) // Setting safety to the lowest possible values since Gemini is already powerless enough @@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques MaxOutputTokens: textRequest.MaxTokens, }, } - if textRequest.Functions != nil { + if textRequest.Tools != nil { + functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) + for _, tool := range textRequest.Tools { + functions = append(functions, tool.Function) + } + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: functions, + }, + } + } else if textRequest.Functions != nil { geminiRequest.Tools = []GeminiChatTools{ { FunctionDeclarations: 
textRequest.Functions, @@ -126,6 +132,30 @@ func (g *GeminiChatResponse) GetResponseText() string { return "" } +func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall { + var toolCalls []dto.ToolCall + + item := candidate.Content.Parts[0] + if item.FunctionCall == nil { + return toolCalls + } + argsBytes, err := json.Marshal(item.FunctionCall.Arguments) + if err != nil { + //common.SysError("getToolCalls failed: " + err.Error()) + return toolCalls + } + toolCall := dto.ToolCall{ + ID: fmt.Sprintf("call_%s", common.GetUUID()), + Type: "function", + Function: dto.FunctionCall{ + Arguments: string(argsBytes), + Name: item.FunctionCall.FunctionName, + }, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), @@ -144,8 +174,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp FinishReason: relaycommon.StopFinishReason, } if len(candidate.Content.Parts) > 0 { - content, _ = json.Marshal(candidate.Content.Parts[0].Text) - choice.Message.Content = content + if candidate.Content.Parts[0].FunctionCall != nil { + choice.Message.ToolCalls = getToolCalls(&candidate) + } else { + choice.Message.SetStringContent(candidate.Content.Parts[0].Text) + } } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } @@ -154,7 +187,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.SetContentString(geminiResponse.GetResponseText()) + //choice.Delta.SetContentString(geminiResponse.GetResponseText()) + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + respFirst := 
geminiResponse.Candidates[0].Content.Parts[0] + if respFirst.FunctionCall != nil { + // function response + choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0]) + } else { + // text response + choice.Delta.SetContentString(respFirst.Text) + } + } choice.FinishReason = &relaycommon.StopFinishReason var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" @@ -165,92 +208,47 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseText := "" - responseJson := "" id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() var usage = &dto.Usage{} - dataChan := make(chan string, 5) - stopChan := make(chan bool, 2) scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + for scanner.Scan() { + data := scanner.Text() + info.SetFirstResponseTime() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "data: ") { + continue } - if atEOF { - return len(data), data, nil + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\"") + var geminiResponse GeminiChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err != nil { + common.LogError(c, "error unmarshalling stream response: "+err.Error()) + continue } - return 0, nil, nil - }) - go func() { - for scanner.Scan() { - data := scanner.Text() - responseJson += data - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "\"text\": \"") { - continue - } - data = strings.TrimPrefix(data, "\"text\": \"") - data = strings.TrimSuffix(data, "\"") - if 
!common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } + + response := streamResponseGeminiChat2OpenAI(&geminiResponse) + if response == nil { + continue } - stopChan <- true - }() - isFirst := true - service.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - // this is used to prevent annoying \ related format bug - data = fmt.Sprintf("{\"content\": \"%s\"}", data) - type dummyStruct struct { - Content string `json:"content"` - } - var dummy dummyStruct - err := json.Unmarshal([]byte(data), &dummy) - responseText += dummy.Content - var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.SetContentString(dummy.Content) - response := dto.ChatCompletionsStreamResponse{ - Id: id, - Object: "chat.completion.chunk", - Created: createAt, - Model: info.UpstreamModelName, - Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, - } - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - return false + response.Id = id + response.Created = createAt + responseText += response.Choices[0].Delta.GetContentString() + if geminiResponse.UsageMetadata.TotalTokenCount != 0 { + usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount + usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount } - }) - var geminiChatResponses []GeminiChatResponse - err := json.Unmarshal([]byte(responseJson), &geminiChatResponses) - if err != nil { - log.Printf("cannot get gemini usage: %s", err.Error()) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - } else 
{ - for _, response := range geminiChatResponses { - usage.PromptTokens = response.UsageMetadata.PromptTokenCount - usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, err.Error()) } - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + if info.ShouldIncludeUsage { response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := service.ObjectData(c, response) @@ -259,10 +257,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom } } service.Done(c) - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage - } + resp.Body.Close() return nil, usage } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 48616b6c8..6a04d0881 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -36,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, 
relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 76de148da..408db6aae 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -10,16 +10,22 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" - "one-api/service" ) type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -36,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case relayconstant.RelayModeEmbeddings: return requestOpenAI2Embeddings(*request), nil default: @@ -58,11 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c 
*gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - } + err, usage = openai.OaiStreamHandler(c, resp, info) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 00f01fdc9..4388efd6d 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -1,10 +1,13 @@ package openai import ( + "bytes" + "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" "io" + "mime/multipart" "net/http" "one-api/common" "one-api/dto" @@ -14,22 +17,16 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" - "one-api/service" + "one-api/relay/constant" "strings" ) type Adaptor struct { - ChannelType int + ChannelType int + ResponseFormat string } -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil -} - -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { -} - -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } @@ -74,28 +71,84 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info 
*relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } + if info.ChannelType != common.ChannelTypeOpenAI { + request.StreamOptions = nil + } + return request, nil +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + a.ResponseFormat = request.ResponseFormat + if info.RelayMode == constant.RelayModeAudioSpeech { + jsonData, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("error marshalling object: %w", err) + } + return bytes.NewReader(jsonData), nil + } else { + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + writer.WriteField("model", request.Model) + + // attach the uploaded file field + file, header, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + + part, err := writer.CreateFormFile("file", header.Filename) + if err != nil { + return nil, errors.New("create form file failed") + } + if _, err := io.Copy(part, file); err != nil { + return nil, errors.New("copy file failed") + } + + // close the multipart writer to finalize the boundary + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &requestBody, nil + } +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return request, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return channel.DoApiRequest(a, c, info, requestBody) + if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + return channel.DoFormRequest(a, c, info, requestBody) + } else { + return
channel.DoApiRequest(a, c, info, requestBody) + } } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - var responseText string - var toolCount int - err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - usage.CompletionTokens += toolCount * 7 + switch info.RelayMode { + case constant.RelayModeAudioSpeech: + err, usage = OpenaiTTSHandler(c, resp, info) + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) + case constant.RelayModeImagesGenerations: + err, usage = OpenaiTTSHandler(c, resp, info) + default: + if info.IsStream { + err, usage = OaiStreamHandler(c, resp, info) + } else { + err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } - } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go index 26ba14735..5d12d3676 100644 --- a/relay/channel/openai/constant.go +++ b/relay/channel/openai/constant.go @@ -9,6 +9,7 @@ var ModelList = []string{ "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-vision-preview", "gpt-4o", "gpt-4o-2024-05-13", + "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", "text-moderation-latest", "text-moderation-stable", diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index dace39cbe..807f4b18f 100644 --- 
a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -4,6 +4,8 @@ import ( "bufio" "bytes" "encoding/json" + "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "io" "net/http" @@ -14,38 +16,36 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "strings" - "sync" "time" ) -func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) { - //checkSensitive := constant.ShouldCheckCompletionSensitive() +func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + containStreamUsage := false + responseId := "" + var createAt int64 = 0 + var systemFingerprint string + model := info.UpstreamModelName + var responseTextBuilder strings.Builder - var usage dto.Usage + var usage = &dto.Usage{} + var streamItems []string // store stream items + toolCount := 0 scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string, 5) - stopChan := make(chan bool, 2) + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + + ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) + defer ticker.Stop() + + stopChan := make(chan bool) defer close(stopChan) - defer close(dataChan) - var wg sync.WaitGroup - go func() { - wg.Add(1) - defer wg.Done() - var streamItems []string // store stream items + + gopool.Go(func() { for scanner.Scan() { + info.SetFirstResponseTime() + ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong 
format continue @@ -53,54 +53,46 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. if data[:6] != "data: " && data[:6] != "[DONE]" { continue } - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } data = data[6:] if !strings.HasPrefix(data, "[DONE]") { + err := service.StringData(c, data) + if err != nil { + common.LogError(c, "streaming error: "+err.Error()) + } streamItems = append(streamItems, data) } } - // 计算token - streamResp := "[" + strings.Join(streamItems, ",") + "]" - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - var streamResponses []dto.ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - if streamResponse.Usage != nil { - if streamResponse.Usage.TotalTokens != 0 { - usage = *streamResponse.Usage - } - } - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - } else { - for _, streamResponse := range streamResponses { - if streamResponse.Usage != nil { - if streamResponse.Usage.TotalTokens != 0 { - usage = *streamResponse.Usage - } + common.SafeSendBool(stopChan, true) + }) + + select { + case <-ticker.C: + // 超时处理逻辑 + 
common.LogError(c, "streaming timeout") + case <-stopChan: + // 正常结束 + } + + // 计算token + streamResp := "[" + strings.Join(streamItems, ",") + "]" + switch info.RelayMode { + case relayconstant.RelayModeChatCompletions: + var streamResponses []dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) + if err == nil { + responseId = streamResponse.Id + createAt = streamResponse.Created + systemFingerprint = streamResponse.GetSystemFingerprint() + model = streamResponse.Model + if service.ValidUsage(streamResponse.Usage) { + usage = streamResponse.Usage + containStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -116,67 +108,69 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. 
} } } - case relayconstant.RelayModeCompletions: - var streamResponses []dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) + } else { + for _, streamResponse := range streamResponses { + responseId = streamResponse.Id + createAt = streamResponse.Created + systemFingerprint = streamResponse.GetSystemFingerprint() + model = streamResponse.Model + if service.ValidUsage(streamResponse.Usage) { + usage = streamResponse.Usage + containStreamUsage = true + } + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) } } } - } else { - for _, streamResponse := range streamResponses { + } + } + case relayconstant.RelayModeCompletions: + var streamResponses []dto.CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) + if err == nil { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Text) } } } - } - if len(dataChan) > 0 { - // wait data out - 
time.Sleep(2 * time.Second) - } - common.SafeSendBool(stopChan, true) - }() - service.SetEventStreamHeaders(c) - isFirst := true - ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) - defer ticker.Stop() - c.Stream(func(w io.Writer) bool { - select { - case <-ticker.C: - common.LogError(c, "reading data from upstream timeout") - return false - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] + } else { + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false } - }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount } - wg.Wait() - return nil, &usage, responseTextBuilder.String(), toolCount + + if !containStreamUsage { + usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 + } + + if info.ShouldIncludeUsage && !containStreamUsage { + response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) + response.SetSystemFingerprint(systemFingerprint) + service.ObjectData(c, response) + } + + service.Done(c) + + resp.Body.Close() + return nil, usage } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -213,11 +207,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return service.OpenAIErrorWrapper(err, 
"copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - + resp.Body.Close() if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { @@ -232,3 +222,134 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } return nil, &simpleResponse.Usage } + +func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. 
+ for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.TotalTokens = info.PromptTokens + return nil, usage +} + +func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var audioResp dto.AudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. 
+ for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + resp.Body.Close() + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + } + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return nil, usage +} + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse dto.WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse dto.AudioResponse + if err := json.Unmarshal(body, &whisperResponse); 
err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 8f6dd0ae8..d8c4ffb9d 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -33,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 3c65b2d01..e9d07fbea 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -10,18 +10,22 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" - "one-api/service" ) type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, 
info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -34,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } @@ -54,11 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - } + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index d79330e63..5811c8735 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -23,12 +23,17 @@ type Adaptor struct { Timestamp int64 } -func (a *Adaptor) InitRerank(info 
*relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.Action = "ChatCompletions" a.Version = "2023-09-01" a.Timestamp = common.GetTimestamp() @@ -47,7 +52,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 9852aa19b..f499bec89 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -16,12 +16,17 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a 
*Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -33,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 0893a8358..f98581fca 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -14,12 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -37,7 +42,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 5ef9d7ab8..aaf3c5dd4 100644 --- 
a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage *dto.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { - return i + 2, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) dataChan := make(chan string) metaChan := make(chan string) stopChan := make(chan bool) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 508861fc2..5e0906efe 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -10,18 +10,22 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" - "one-api/service" ) type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -35,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } 
-func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } @@ -55,13 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - var toolCount int - err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - usage.CompletionTokens += toolCount * 7 - } + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/zhipu_4v/constants.go b/relay/channel/zhipu_4v/constants.go index 1b0b0cc3e..3383eb3f8 100644 --- a/relay/channel/zhipu_4v/constants.go +++ b/relay/channel/zhipu_4v/constants.go @@ -1,7 +1,7 @@ package zhipu_4v var ModelList = []string{ - "glm-4", "glm-4v", "glm-3-turbo", + "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", } var ChannelName = "zhipu_4v" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 42c8381eb..564a7adb2 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -17,6 +17,7 @@ type RelayInfo struct { TokenUnlimited bool StartTime time.Time FirstResponseTime time.Time + setFirstResponse bool ApiType int IsStream bool RelayMode int @@ -68,7 +69,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { info.ApiVersion = GetAPIVersion(c) } if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == 
common.ChannelTypeAnthropic || - info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini { + info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || + info.ChannelType == common.ChannelCloudflare { info.SupportStreamOptions = true } return info @@ -82,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) { info.IsStream = isStream } +func (info *RelayInfo) SetFirstResponseTime() { + if !info.setFirstResponse { + info.FirstResponseTime = time.Now() + info.setFirstResponse = true + } +} + type TaskRelayInfo struct { ChannelType int ChannelId int diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 9ef9a8b90..6daf003a6 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -1,50 +1,17 @@ package common import ( - "encoding/json" "fmt" "github.com/gin-gonic/gin" _ "image/gif" _ "image/jpeg" _ "image/png" - "io" - "net/http" "one-api/common" - "one-api/dto" - "strconv" "strings" ) var StopFinishReason = "stop" -func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { - OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: dto.OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var textResponse dto.TextResponseWithError - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody) - return - } - OpenAIErrorWithStatusCode.Error = textResponse.Error - return -} - func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := 
fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 0ce2657bb..6bd93c4d2 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -22,6 +22,7 @@ const ( APITypeCohere APITypeDify APITypeJina + APITypeCloudflare APITypeDummy // this one is only for count, do not add any channel after this ) @@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeDify case common.ChannelTypeJina: apiType = APITypeJina + case common.ChannelCloudflare: + apiType = APITypeCloudflare } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index ed15b08c0..a072c740c 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -13,6 +13,7 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits + RelayModeMidjourneyImagine RelayModeMidjourneyDescribe RelayModeMidjourneyBlend @@ -22,16 +23,19 @@ const ( RelayModeMidjourneyTaskFetch RelayModeMidjourneyTaskImageSeed RelayModeMidjourneyTaskFetchByCondition - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation RelayModeMidjourneyAction RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + + RelayModeAudioSpeech // tts + RelayModeAudioTranscription // whisper + RelayModeAudioTranslation // whisper + RelayModeSunoFetch RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeRerank ) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 913772122..b2fadcc34 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -1,13 +1,10 @@ package relay import ( - "bytes" - "context" "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" - "io" "net/http" "one-api/common" "one-api/constant" @@ -16,69 +13,71 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" - "strings" - "time" ) -func AudioHelper(c *gin.Context, relayMode int) 
*dto.OpenAIErrorWithStatusCode { - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - var audioRequest dto.TextToSpeechRequest - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - err := common.UnmarshalBodyReusable(c, &audioRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) +func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err + } + switch info.RelayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + if constant.ShouldCheckPromptSensitive() { + err := service.CheckSensitiveInput(audioRequest.Input) + if err != nil { + return nil, err + } } - } else { - audioRequest = dto.TextToSpeechRequest{ - Model: "whisper-1", + default: + if audioRequest.Model == "" { + audioRequest.Model = c.PostForm("model") + } + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" } } - //err := common.UnmarshalBodyReusable(c, &audioRequest) + return audioRequest, nil +} - // request validation - if audioRequest.Model == "" { - return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } +func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfo(c) + audioRequest, err := getAndValidAudioRequest(c, relayInfo) - if strings.HasPrefix(audioRequest.Model, "tts-1") { - if audioRequest.Voice == "" { - return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", 
http.StatusBadRequest) - } + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest) } - var err error + promptTokens := 0 preConsumedTokens := common.PreConsumedQuota - if strings.HasPrefix(audioRequest.Model, "tts-1") { - if constant.ShouldCheckPromptSensitive() { - err = service.CheckSensitiveInput(audioRequest.Input) - if err != nil { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) - } - } + if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } preConsumedTokens = promptTokens + relayInfo.PromptTokens = promptTokens } + modelRatio := common.GetModelRatio(audioRequest.Model) - groupRatio := common.GetGroupRatio(group) + groupRatio := common.GetGroupRatio(relayInfo.Group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) } @@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { preConsumedQuota = 0 } if preConsumedQuota > 0 { - userQuota, err = 
model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } - succeed := false - defer func() { - if succeed { - return - } - if preConsumedQuota > 0 { - // we need to roll back the pre-consumed quota - defer func() { - go func() { - // negative means add quota back for token & user - returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota) - }() - }() - } - }() - // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" { @@ -122,133 +105,44 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { audioRequest.Model = modelMap[audioRequest.Model] } } + relayInfo.UpstreamModelName = audioRequest.Model - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(relayInfo) - fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) - if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := relaycommon.GetAPIVersion(c) - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) - } - - requestBody := c.Request.Body - - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return service.OpenAIErrorWrapper(err, 
"new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - req.Header.Set("api-key", apiKey) - req.ContentLength = c.Request.ContentLength - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := service.GetHttpClient().Do(req) + resp, err := adaptor.DoRequest(c, relayInfo, ioReader) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - - if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) - } - succeed = true - - var audioResponse dto.AudioResponse - - defer func(ctx context.Context) { - go func() { - useTimeSeconds := time.Now().Unix() - startTime.Unix() - quota := 0 - if strings.HasPrefix(audioRequest.Model, "tts-1") { - quota = promptTokens - } else { - quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model) - } - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.SysError("error 
consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - other := make(map[string]interface{}) - other["model_ratio"] = modelRatio - other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - if strings.HasPrefix(audioRequest.Model, "tts-1") { - - } else { - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - contains, words := service.SensitiveWordContains(audioResponse.Text) - if contains { - return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest) + statusCodeMappingStr := c.GetString("status_code_mapping") + if resp != nil { + if resp.StatusCode != http.StatusOK { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } } - resp.Body = 
io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) + usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } + postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") + return nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index d83ec269c..f6a2641bb 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "context" "encoding/json" "errors" "fmt" @@ -14,72 +13,71 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/service" "strings" - "time" ) -func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - var imageRequest dto.ImageRequest - err := common.UnmarshalBodyReusable(c, &imageRequest) +func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return nil, err } - - 
if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("invalid size parameter: please use 'x' instead of the multiplication sign '×'") } if imageRequest.N == 0 { imageRequest.N = 1 } - // Prompt validation - if imageRequest.Prompt == "" { - return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" } - - if constant.ShouldCheckPromptSensitive() { - err = service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) - } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" } - - if strings.Contains(imageRequest.Size, "×") { - return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" } // Not "256x256", "512x512", or "1024x1024" if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 (dall-e-3 supports 1024x1024, 1024x1792, or 1792x1024)") } } else if imageRequest.Model == "dall-e-3" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && 
imageRequest.Size != "1792x1024" { - return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return nil, errors.New("size must be one of 1024x1024, 1024x1792, or 1792x1024 for dall-e-3 (dall-e-2 supports 256x256, 512x512, or 1024x1024)") } - if imageRequest.N != 1 { - return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest) + //if imageRequest.N != 1 { + // return nil, errors.New("n must be 1") + //} + } + // N should be between 1 and 10 + //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { + // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + //} + if constant.ShouldCheckPromptSensitive() { + err := service.CheckSensitiveInput(imageRequest.Prompt) + if err != nil { + return nil, err } } + return imageRequest, nil +} - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) +func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfo(c) + + imageRequest, err := getAndValidImageRequest(c, relayInfo) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name modelMapping := c.GetString("model_mapping") - isModelMapped := false if modelMapping != "" { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) @@ -88,31 +86,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if modelMap[imageRequest.Model] != "" { imageRequest.Model = modelMap[imageRequest.Model] - isModelMapped = true - } - } - 
baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := relaycommon.GetAPIVersion(c) - // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) - } - var requestBody io.Reader - if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body - jsonStr, err := json.Marshal(imageRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body } + relayInfo.UpstreamModelName = imageRequest.Model modelPrice, success := common.GetModelPrice(imageRequest.Model, true) if !success { @@ -121,8 +97,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC // per 1 modelRatio = $0.04 / 16 modelPrice = 0.0025 * modelRatio } - groupRatio := common.GetGroupRatio(group) - userQuota, err := model.CacheGetUserQuota(userId) + + groupRatio := common.GetGroupRatio(relayInfo.Group) + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) sizeRatio := 1.0 // Size @@ -150,98 +127,60 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, 
fullRequestURL, requestBody) - if err != nil { - return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(relayInfo) - token := c.Request.Header.Get("Authorization") - if channelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) + var requestBody io.Reader - resp, err := service.GetHttpClient().Do(req) + convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() + jsonData, err := json.Marshal(convertedRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } + requestBody = bytes.NewBuffer(jsonData) - if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - var textResponse dto.ImageResponse - defer 
func(ctx context.Context) { - useTimeSeconds := time.Now().Unix() - startTime.Unix() + if resp != nil { + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { - return - } - err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } - if quota != 0 { - tokenName := c.GetString("token_name") - quality := "normal" - if imageRequest.Quality == "hd" { - quality = "hd" - } - logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality) - other := make(map[string]interface{}) - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v 
:= range resp.Header { - c.Writer.Header().Set(k, v[0]) + _, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + usage := &dto.Usage{ + PromptTokens: imageRequest.N, + TotalTokens: imageRequest.N, } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + + quality := "standard" + if imageRequest.Quality == "hd" { + quality = "hd" } + + logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) + postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent) + return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 79df191eb..40610acb7 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } } relayInfo.UpstreamModelName = textRequest.Model - modelPrice, success := common.GetModelPrice(textRequest.Model, false) + modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) groupRatio := common.GetGroupRatio(relayInfo.Group) var preConsumedQuota int @@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } - if !success { + if !getModelPriceSuccess { preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + int(textRequest.MaxTokens) @@ -150,10 +150,10 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if adaptor == nil { return 
service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } - adaptor.Init(relayInfo, *textRequest) + adaptor.Init(relayInfo) var requestBody io.Reader - convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest) + convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } @@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } @@ -288,7 +288,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool) { + modelPrice float64, usePrice bool, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -309,7 +309,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN } totalTokens := promptTokens + completionTokens var logContent string - if modelPrice == -1 { + if !usePrice { logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) @@ -350,6 +350,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN logModel = "g-*" 
logContent += fmt.Sprintf(",模型 %s", modelName) } + if extraContent != "" { + logContent += ", " + extraContent + } other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 8998540db..4c0aef183 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -7,6 +7,7 @@ import ( "one-api/relay/channel/aws" "one-api/relay/channel/baidu" "one-api/relay/channel/claude" + "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &dify.Adaptor{} case constant.APITypeJina: return &jina.Adaptor{} + case constant.APITypeCloudflare: + return &cloudflare.Adaptor{} } return nil } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index e32ca8833..9885fd3ec 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } - adaptor.InitRerank(relayInfo, *rerankRequest) + adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { @@ -99,6 +99,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, 
rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") return nil } diff --git a/service/channel.go b/service/channel.go index 76be27100..5716a6d33 100644 --- a/service/channel.go +++ b/service/channel.go @@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus return false } -func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool { +func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } if err != nil { return false } - if openAIErr != nil { + if openaiWithStatusErr != nil { return false } if status != common.ChannelStatusAutoDisabled { diff --git a/service/error.go b/service/error.go index 0f6d472fb..3410de81d 100644 --- a/service/error.go +++ b/service/error.go @@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, Error: dto.OpenAIError{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), }, } responseBody, err := io.ReadAll(resp.Body) diff --git a/service/sse.go b/service/relay.go similarity index 58% rename from service/sse.go rename to service/relay.go index 2d531a4e2..03b005c3d 100644 --- a/service/sse.go +++ b/service/relay.go @@ -2,10 +2,11 @@ package service import ( "encoding/json" + "errors" "fmt" "github.com/gin-gonic/gin" + "net/http" "one-api/common" - "strings" ) func SetEventStreamHeaders(c *gin.Context) { @@ -16,11 +17,16 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } -func StringData(c *gin.Context, str string) { - str = strings.TrimPrefix(str, "data: ") - str 
= strings.TrimSuffix(str, "\r") +func StringData(c *gin.Context, str string) error { + //str = strings.TrimPrefix(str, "data: ") + //str = strings.TrimSuffix(str, "\r") c.Render(-1, common.CustomEvent{Data: "data: " + str}) - c.Writer.Flush() + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } else { + return errors.New("streaming error: flusher not found") + } + return nil } func ObjectData(c *gin.Context, object interface{}) error { @@ -28,10 +34,14 @@ func ObjectData(c *gin.Context, object interface{}) error { if err != nil { return fmt.Errorf("error marshalling object: %w", err) } - StringData(c, string(jsonData)) - return nil + return StringData(c, string(jsonData)) } func Done(c *gin.Context) { - StringData(c, "[DONE]") + _ = StringData(c, "[DONE]") +} + +func GetResponseID(c *gin.Context) string { + logID := c.GetString("X-Oneapi-Request-Id") + return fmt.Sprintf("chatcmpl-%s", logID) } diff --git a/service/token_counter.go b/service/token_counter.go index 5189a2a97..acf35f7b9 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -9,6 +9,7 @@ import ( "log" "math" "one-api/common" + "one-api/constant" "one-api/dto" "strings" "unicode/utf8" @@ -71,13 +72,20 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { } func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { - // TODO: 非流模式下不计算图片token数量 if model == "glm-4v" { return 1047, nil } if imageUrl.Detail == "low" { return 85, nil } + // TODO: 非流模式下不计算图片token数量 + if !constant.GetMediaTokenNotStream && !stream { + return 1000, nil + } + // 是否统计图片token + if !constant.GetMediaToken { + return 1000, nil + } // 同步One API的图片计费逻辑 if imageUrl.Detail == "auto" || imageUrl.Detail == "" { imageUrl.Detail = "high" diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 528f3d48d..adec566da 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, 
createAt int64, model string, usage d Usage: &usage, } } + +func ValidUsage(usage *dto.Usage) bool { + return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0) +} diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 92a0a17a8..d578c15e7 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -367,7 +367,7 @@ const LogsTable = () => { dataIndex: 'content', render: (text, record, index) => { let other = getLogOther(record.other); - if (other == null) { + if (other == null || record.type !== 2) { return ( { const [inputs, setInputs] = useState({ @@ -46,9 +40,7 @@ const RegisterForm = () => { let navigate = useNavigate(); - function handleChange(e) { - const { name, value } = e.target; - console.log(name, value); + function handleChange(name, value) { setInputs((inputs) => ({ ...inputs, [name]: value })); } @@ -108,96 +100,116 @@ const RegisterForm = () => { }; return ( - - -
- 新用户注册 -
-
- - - - - {showEmailVerification ? ( - <> - - 获取验证码 - - } +
+ + + +
+
+ + + 新用户注册 + + + handleChange('username', value)} + /> + handleChange('password', value)} + /> + handleChange('password2', value)} + /> + {showEmailVerification ? ( + <> + handleChange('email', value)} + name='email' + type='email' + suffix={ + + } + /> + + handleChange('verification_code', value) + } + name='verification_code' + /> + + ) : ( + <> + )} + + +
+ + 已有账户? + 点击登录 + +
+
+ {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} /> - - - ) : ( - <> - )} - {turnstileEnabled ? ( - { - setTurnstileToken(token); - }} - /> - ) : ( - <> - )} - - - - - 已有账户? - - 点击登录 - - - - + ) : ( + <> + )} +
+
+
+
+
); }; diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index ff1d281c2..816d1732e 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -99,6 +99,13 @@ export const CHANNEL_OPTIONS = [ color: 'orange', label: 'Google PaLM2', }, + { + key: 39, + text: 'Cloudflare', + value: 39, + color: 'grey', + label: 'Cloudflare', + }, { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 826f3d17d..d3c70b647 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -601,6 +601,24 @@ const EditChannel = (props) => { /> )} + {inputs.type === 39 && ( + <> +
+ Account ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
模型: