diff --git a/common/gin.go b/common/gin.go index bed2c2b178..b6ef96a6e9 100644 --- a/common/gin.go +++ b/common/gin.go @@ -8,12 +8,24 @@ import ( "strings" ) -func UnmarshalBodyReusable(c *gin.Context, v any) error { +const KeyRequestBody = "key_request_body" + +func GetRequestBody(c *gin.Context) ([]byte, error) { + requestBody, _ := c.Get(KeyRequestBody) + if requestBody != nil { + return requestBody.([]byte), nil + } requestBody, err := io.ReadAll(c.Request.Body) if err != nil { - return err + return nil, err } - err = c.Request.Body.Close() + _ = c.Request.Body.Close() + c.Set(KeyRequestBody, requestBody) + return requestBody.([]byte), nil +} + +func UnmarshalBodyReusable(c *gin.Context, v any) error { + requestBody, err := GetRequestBody(c) if err != nil { return err } diff --git a/controller/relay.go b/controller/relay.go index 240042b662..499e8ddc39 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,9 +1,11 @@ package controller import ( + "bytes" "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -13,6 +15,7 @@ import ( "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" + "io" "net/http" ) @@ -50,8 +53,8 @@ func Relay(c *gin.Context) { go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes - if !shouldRetry(bizErr.StatusCode) { - logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode) + if !shouldRetry(c, bizErr.StatusCode) { + logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) retryTimes = 0 } for i := retryTimes; i > 0; i-- { @@ -65,6 +68,8 @@ func Relay(c *gin.Context) { continue } middleware.SetupContextForSelectedChannel(c, channel, originalModel) + requestBody, err := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) bizErr = relay(c, relayMode) if bizErr == nil { return @@ -85,7 +90,10 @@ func Relay(c *gin.Context) { } } -func shouldRetry(statusCode int) bool { +func shouldRetry(c *gin.Context, statusCode int) bool { + if _, ok := c.Get("specific_channel_id"); ok { + return false + } if statusCode == http.StatusTooManyRequests { return true }