diff --git a/.gitignore b/.gitignore index 2a8ae16e82..2158e42996 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ build logs data /web/node_modules -cmd.md \ No newline at end of file +cmd.md +vendor/* diff --git a/common/audit/audit.go b/common/audit/audit.go new file mode 100644 index 0000000000..2af7d5f83e --- /dev/null +++ b/common/audit/audit.go @@ -0,0 +1,27 @@ +package audit + +import ( + "github.com/sirupsen/logrus" + "gopkg.in/natefinch/lumberjack.v2" +) + +var ( + loger *lumberjack.Logger + logger *logrus.Logger +) + +func init() { + loger = &lumberjack.Logger{ + Filename: "logs/audit.log", + MaxSize: 50, // megabytes + MaxBackups: 300, + MaxAge: 90, // days + } + logger = logrus.New() + logger.SetOutput(loger) + logger.SetFormatter(&logrus.JSONFormatter{}) +} + +func Logger() *logrus.Logger { + return logger +} diff --git a/common/audit/response.go b/common/audit/response.go new file mode 100644 index 0000000000..c3cd64fcf8 --- /dev/null +++ b/common/audit/response.go @@ -0,0 +1,79 @@ +package audit + +import ( + "bytes" + "encoding/base64" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +type AuditLogger struct { + gin.ResponseWriter + buf *bytes.Buffer +} + +func (l *AuditLogger) Write(p []byte) (int, error) { + l.buf.Write(p) + return l.ResponseWriter.Write(p) +} + +func CaptureResponseBody(c *gin.Context) *bytes.Buffer { + al := &AuditLogger{ + ResponseWriter: c.Writer, + buf: &bytes.Buffer{}, + } + c.Writer = al + return al.buf +} + +func B64encode(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +type AuditReadCloser struct { + Reader io.Reader + Closer io.Closer + Buffer *bytes.Buffer +} + +func (arc *AuditReadCloser) Read(p []byte) (int, error) { + n, err := arc.Reader.Read(p) + if n > 0 { + arc.Buffer.Write(p[:n]) + } + return n, err +} + +func (arc *AuditReadCloser) Close() error { + return arc.Closer.Close() +} + +func CaptureHTTPResponseBody(resp *http.Response) *bytes.Buffer { + buf := &bytes.Buffer{} + arc := &AuditReadCloser{ + Reader: resp.Body, + Closer: resp.Body, + Buffer: buf, + } + resp.Body = arc + return buf +} + +func ParseOPENAIStreamResponse(buf *bytes.Buffer) string { + lines := strings.Split(buf.String(), "\n") + bts := []string{} + for _, line := range lines { + line = strings.TrimSpace(line) + line = strings.Trim(line, "\n") + if strings.HasPrefix(string(line), "data:") { + line = line[5:] + } + content := gjson.Get(line, "choices.0.delta.content").String() + bts = append(bts, content) + } + return strings.Join(bts, "") +} diff --git a/common/config/config.go b/common/config/config.go index 4f1c25b676..9423f37a9a 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,13 +1,14 @@ package config import ( - "github.com/songquanpeng/one-api/common/env" "os" "strconv" "strings" "sync" "time" + "github.com/songquanpeng/one-api/common/env" + "github.com/google/uuid" ) @@ -55,6 +56,8 @@ var EmailDomainWhitelist = []string{ var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true" var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true" var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true" +var ClientAuditEnabled = env.Bool("CLIENT_AUDIT_ENABLED", false) +var UpstreamAuditEnabled = env.Bool("UPSTREAM_AUDIT_ENABLED", false) var LogConsumeEnabled = true @@ -135,6 +138,7 @@ var ( var RateLimitKeyExpirationDuration = 20 * time.Minute +var EnableBilling = env.Bool("ENABLE_BILLING", true) var EnableMetric = env.Bool("ENABLE_METRIC", false) var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) diff --git a/controller/relay.go b/controller/relay.go index 5d8ac69039..77cf9df624 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/audit" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" @@ -17,48 +18,97 @@ import ( dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" ) // https://platform.openai.com/docs/api-reference/chat -func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { +type Options struct { + Debug bool + EnableMonitor bool + EnableBilling bool +} + +type RelayController struct { + opts Options + controller.RelayInstance + monitor.MonitorInstance +} + +func NewRelayController(opts Options) *RelayController { + ctrl := &RelayController{ + opts: opts, + } + ctrl.RelayInstance = controller.NewRelayInstance(controller.Options{ + EnableBilling: opts.EnableBilling, + }) + if opts.EnableMonitor { + ctrl.MonitorInstance = monitor.NewMonitorInstance() + } + return ctrl +} + +func (ctrl *RelayController) relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { + if config.ClientAuditEnabled { + buf := audit.CaptureResponseBody(c) + m := meta.GetByContext(c) + defer func() { + audit.Logger(). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("parsed", audit.ParseOPENAIStreamResponse(buf)). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(m.ToLogrusFields()). + Info("client response") + }() + } var err *model.ErrorWithStatusCode switch relayMode { case relaymode.ImagesGenerations: - err = controller.RelayImageHelper(c, relayMode) + err = ctrl.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: fallthrough case relaymode.AudioTranslation: fallthrough case relaymode.AudioTranscription: - err = controller.RelayAudioHelper(c, relayMode) + err = ctrl.RelayAudioHelper(c, relayMode) default: - err = controller.RelayTextHelper(c) + err = ctrl.RelayTextHelper(c) } return err } -func Relay(c *gin.Context) { +func (ctrl *RelayController) Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) if config.DebugEnabled { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } + if config.ClientAuditEnabled { + requestBody, _ := common.GetRequestBody(c) + m := meta.GetByContext(c) + audit.Logger(). + WithField("raw", audit.B64encode(requestBody)). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(m.ToLogrusFields()). + Info("client request") + } channelId := c.GetInt(ctxkey.ChannelId) - userId := c.GetInt("id") - bizErr := relayHelper(c, relayMode) + bizErr := ctrl.relayHelper(c, relayMode) if bizErr == nil { - monitor.Emit(channelId, true) + if ctrl.MonitorInstance != nil { + ctrl.Emit(channelId, true) + } return } lastFailedChannelId := channelId channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + userId := c.GetInt(ctxkey.Id) + go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { @@ -77,15 +127,19 @@ func Relay(c *gin.Context) { } middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) + if err != nil { + logger.Errorf(ctx, "GetRequestBody failed: %+v", err) + break + } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - bizErr = relayHelper(c, relayMode) + bizErr = ctrl.relayHelper(c, relayMode) if bizErr == nil { return } channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go ctrl.processChannelRelayError(ctx, userId, channelId, channelName, bizErr) } if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { @@ -117,13 +171,16 @@ func shouldRetry(c *gin.Context, statusCode int) bool { return true } -func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { +func (ctrl *RelayController) processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { + if ctrl.MonitorInstance == nil { + return + } logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors - if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { - monitor.DisableChannel(channelId, channelName, err.Message) + if ctrl.ShouldDisableChannel(&err.Error, err.StatusCode) { + ctrl.DisableChannel(channelId, channelName, err.Message) } else { - monitor.Emit(channelId, false) + ctrl.Emit(channelId, false) } } diff --git a/go.mod b/go.mod index 1ed937ae5f..8fb66d83e8 100644 --- a/go.mod +++ b/go.mod @@ -20,10 +20,13 @@ require ( github.com/jinzhu/copier v0.4.0 github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.7 + github.com/sirupsen/logrus v1.8.1 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.17.1 golang.org/x/crypto v0.23.0 golang.org/x/image v0.16.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 @@ -73,6 +76,8 @@ require ( github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/smarty/assertions v1.15.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect diff --git a/go.sum b/go.sum index a5aede9569..56ecf850b6 100644 --- a/go.sum +++ b/go.sum @@ -1,40 +1,25 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= -github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= -github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= -github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo= github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= -github.com/bytedance/sonic v1.11.5 h1:G00FYjjqll5iQ1PYXynbg/hyzqBqavH8Mo9/oTopd9k= -github.com/bytedance/sonic v1.11.5/go.mod h1:X2PC2giUdj/Cv2lliWFLk6c/DUQok5rViJSemeB0wDw= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.0/go.mod h1:UmRT+IRTGKz/DAkzcEGzyVqQFJ7H9BqwBO3pm9H/+HY= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cloudwego/base64x v0.1.3 h1:b5J/l8xolB7dyDTTmhJP2oTs5LdrjyrUFuNxdfq5hAg= -github.com/cloudwego/base64x v0.1.3/go.mod h1:1+1K5BUHIQzyapgpF7LwvOGAEDicKtt1umPV+aN8pi8= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= @@ -51,26 +36,16 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/gin-contrib/cors v1.7.1 h1:s9SIppU/rk8enVvkzwiC2VK3UZ/0NNGsWfUKvV55rqs= -github.com/gin-contrib/cors v1.7.1/go.mod h1:n/Zj7B4xyrgk/cX1WCX2dkzFfaNm/xJb6oIUk7WTtps= github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= -github.com/gin-contrib/gzip v1.0.0 h1:UKN586Po/92IDX6ie5CWLgMI81obiIp5nSP85T3wlTk= -github.com/gin-contrib/gzip v1.0.0/go.mod h1:CtG7tQrPB3vIBo6Gat9FVUsis+1emjvQqd66ME5TdnE= github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE= github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4= -github.com/gin-contrib/sessions v1.0.0 h1:r5GLta4Oy5xo9rAwMHx8B4wLpeRGHMdz9NafzJAdP8Y= -github.com/gin-contrib/sessions v1.0.0/go.mod h1:DN0f4bvpqMQElDdi+gNGScrP2QEI04IErRyMFyorUOI= github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI= github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-contrib/static v1.1.1 h1:XEvBd4DDLG1HBlyPBQU1XO8NlTpw6mgdqcPteetYA5k= -github.com/gin-contrib/static v1.1.1/go.mod h1:yRGmar7+JYvbMLRPIi4H5TVVSBwULfT9vetnVD0IO74= github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4= github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= -github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= -github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= @@ -78,8 +53,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4= -github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= @@ -87,8 +60,6 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= @@ -147,19 +118,17 @@ github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjY github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= -github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= -github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= -github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= @@ -168,6 +137,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -176,46 +146,41 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= -golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= -golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -228,8 +193,6 @@ gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4c gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.9 h1:wct0gxZIELDk8+ZqF/MVnHLkA1rvYlBWUMv2EdsK1g8= -gorm.io/gorm v1.25.9/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/monitor/monitor.go b/monitor/monitor.go new file mode 100644 index 0000000000..7071ef6065 --- /dev/null +++ b/monitor/monitor.go @@ -0,0 +1,32 @@ +package monitor + +import "github.com/songquanpeng/one-api/relay/model" + +type MonitorInstance interface { + Emit(ChannelId int, success bool) + ShouldDisableChannel(err *model.Error, statusCode int) bool + DisableChannel(channelId int, channelName string, reason string) +} + +type defaultMonitor struct { +} + +func NewMonitorInstance() MonitorInstance { + return &defaultMonitor{} +} + +func (m *defaultMonitor) Emit(channelId int, success bool) { + if success { + metricSuccessChan <- channelId + } else { + metricFailChan <- channelId + } +} + +func (m *defaultMonitor) ShouldDisableChannel(err *model.Error, statusCode int) bool { + return ShouldDisableChannel(err, statusCode) +} + +func (m *defaultMonitor) DisableChannel(channelId int, channelName string, reason string) { + DisableChannel(channelId, channelName, reason) +} diff --git a/relay/billing/billing-instance.go b/relay/billing/billing-instance.go new file mode 100644 index 0000000000..9130b4c13c --- /dev/null +++ b/relay/billing/billing-instance.go @@ -0,0 +1,160 @@ +package billing + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// Bookkeeper 记账员逻辑,用于处理用户的配额消费 +// 预扣配额检测逻辑: +// +// 开始请求前,根据不同的请求类型,预先计算需要消费的配额,根据请求用户和token的配额余量来判断是否有足够的配额来满足这个请求 +// 如果余量配额不能满足这个请求,直接返回错误, 如果余量配额可以满足这个请求,那么预先消费这个配额,然后开始请求, 如果余量远远超过这个请求,那么不需要预先消费配额 +// 由于预先计算的配额不是实际消费的配额,所以需要在请求结束后,根据实际消费的配额来更新用户和token的配额,退费或者扣费。 +type Bookkeeper interface { + // 获取模型的费率 + ModelRatio(model string) float64 + // 获取组的费率 + GroupRation(group string) float64 + // 获取模型的补全费率 + ModelCompletionRatio(model string) float64 + // 根据消费记录,扣除用户,token 的配额 + Consume(ctx context.Context, consumeLog *ConsumeLog) + // 预消费配额, 当用户配额不足时,预消费配额, 预消费成功返回预消费的配额,失败返回错误, 如果预消费的配额为0,表示用户有足够的配额 + PreConsumeQuota(ctx context.Context, preConsumedQuota int64, userId, tokenId int) (int64, *relaymodel.ErrorWithStatusCode) + // 退回预消费的配额, 这通常在调用上游api失败的时候执行 + RefundQuota(ctx context.Context, preConsumedQuota int64, tokenId int) + + // 检测用户是否有足够的配额 + // UserHasEnoughQuota(ctx context.Context, userID int, quota int64) bool + // 检测用户是否有远远超过需求的配额, 如果用户的配额远远超过需求,那么不需要预消费配额 + // UserHasMuchMoreQuota(ctx context.Context, userID int, quota int64) bool +} + +type defaultBookkeeper struct { +} + +func NewBookkeeper() Bookkeeper { + return &defaultBookkeeper{} +} + +func (b *defaultBookkeeper) ModelRatio(model string) float64 { + return billingratio.GetModelRatio(model) +} + +func (b *defaultBookkeeper) GroupRation(group string) float64 { + return billingratio.GetGroupRatio(group) +} + +func (b *defaultBookkeeper) ModelCompletionRatio(model string) float64 { + return billingratio.GetCompletionRatio(model) +} + +func (b *defaultBookkeeper) Ratio(group, model string) float64 { + modelRatio := billingratio.GetModelRatio(model) + groupRatio := billingratio.GetGroupRatio(group) + return modelRatio * groupRatio +} + +// ConsumeLog 消费记录实体 +type ConsumeLog struct { + UserId int + ChannelId int + PromptTokens int + CompletionTokens int + ModelName string + TokenId int + TokenName string + Quota int64 + Content string + PreConsumedQuota int64 +} + +func (b *defaultBookkeeper) UserHasEnoughQuota(ctx context.Context, userID int, quota int64) bool { + userQuota, err := model.CacheGetUserQuota(ctx, userID) + if err != nil { + return false + } + return userQuota >= quota +} + +func (b *defaultBookkeeper) UserHasMuchMoreQuota(ctx context.Context, userID int, quota int64) bool { + userQuota, err := model.CacheGetUserQuota(ctx, userID) + if err != nil { + return false + } + return userQuota > 100*quota +} + +func (b *defaultBookkeeper) Consume(ctx context.Context, consumeLog *ConsumeLog) { + // 更新 access_token 的配额 + quotaDelta := consumeLog.Quota - consumeLog.PreConsumedQuota + err := model.PostConsumeTokenQuota(consumeLog.TokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, consumeLog.UserId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + // 更新用户的配额 + model.UpdateUserUsedQuotaAndRequestCount(consumeLog.UserId, consumeLog.Quota) + // 更新渠道的配额 + model.UpdateChannelUsedQuota(consumeLog.ChannelId, consumeLog.Quota) + // 记录消费日志 + model.RecordConsumeLog( + ctx, + consumeLog.UserId, + consumeLog.ChannelId, + consumeLog.PromptTokens, + consumeLog.CompletionTokens, + consumeLog.ModelName, + consumeLog.TokenName, + consumeLog.Quota, + consumeLog.Content, + ) +} + +func (b *defaultBookkeeper) PreConsumeQuota(ctx context.Context, preConsumedQuota int64, userId, tokenId int) (int64, *relaymodel.ErrorWithStatusCode) { + userQuota, err := model.CacheGetUserQuota(ctx, userId) + if err != nil { + return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + if userQuota-preConsumedQuota < 0 { + return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + return preConsumedQuota, nil +} + +func (b *defaultBookkeeper) RefundQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { + if preConsumedQuota != 0 { + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) + } + } +} diff --git a/relay/billing/billing.go b/relay/billing/billing.go index a99d37ee70..28587661bc 100644 --- a/relay/billing/billing.go +++ b/relay/billing/billing.go @@ -3,10 +3,12 @@ package billing import ( "context" "fmt" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" ) +// ReturnPreConsumedQuota 在请求失败的时候,退回预消费的配额 func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { if preConsumedQuota != 0 { go func(ctx context.Context) { diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 8f9708d080..f500a95b70 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,36 +7,35 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/audit" "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" - billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) -func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { - ctx := c.Request.Context() +func (rl *defaultRelay) RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { meta := meta.GetByContext(c) audioModel := "whisper-1" tokenId := c.GetInt(ctxkey.TokenId) channelType := c.GetInt(ctxkey.Channel) - channelId := c.GetInt(ctxkey.ChannelId) + // channelId := c.GetInt(ctxkey.ChannelId) userId := c.GetInt(ctxkey.Id) group := c.GetString(ctxkey.Group) - tokenName := c.GetString(ctxkey.TokenName) + // tokenName := c.GetString(ctxkey.TokenName) var ttsRequest openai.TextToSpeechRequest if relayMode == relaymode.AudioSpeech { @@ -53,58 +52,45 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := billingratio.GetModelRatio(audioModel) - groupRatio := billingratio.GetGroupRatio(group) - ratio := modelRatio * groupRatio - var quota int64 - var preConsumedQuota int64 + var ( + modelRatio float64 + groupRatio float64 + ratio float64 + quota int64 + preConsumeQuota int64 + preConsumedQuota int64 + bizErr *relaymodel.ErrorWithStatusCode + ) + + if rl.Bookkeeper != nil { + modelRatio = rl.ModelRatio(audioModel) + groupRatio = rl.GroupRation(group) + ratio = modelRatio * groupRatio + } + switch relayMode { + // speech 类型,消费的配额直接根据输入的文本长度计算 case relaymode.AudioSpeech: - preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) - quota = preConsumedQuota + preConsumeQuota = int64(float64(len(ttsRequest.Input)) * ratio) + quota = preConsumeQuota + // 其他类型,假设消费的配额是预设的配额的 ratio 倍 default: - preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) + preConsumeQuota = int64(float64(config.PreConsumedQuota) * ratio) } - userQuota, err := model.CacheGetUserQuota(ctx, userId) - if err != nil { - return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - - // Check if user quota is enough - if userQuota-preConsumedQuota < 0 { - return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + if rl.Bookkeeper != nil { + preConsumedQuota, bizErr = rl.PreConsumeQuota(c, preConsumeQuota, userId, tokenId) + if bizErr != nil { + return bizErr } } + succeed := false defer func() { if succeed { return } - if preConsumedQuota > 0 { - // we need to roll back the pre-consumed quota - defer func(ctx context.Context) { - go func() { - // negative means add quota back for token & user - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) - } - }() - }(c.Request.Context()) + if rl.Bookkeeper != nil { + rl.Bookkeeper.RefundQuota(c.Request.Context(), preConsumedQuota, tokenId) } }() @@ -140,8 +126,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } requestBody := &bytes.Buffer{} - _, err = io.Copy(requestBody, c.Request.Body) - if err != nil { + if _, err := io.Copy(requestBody, c.Request.Body); err != nil { return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) @@ -151,6 +136,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + audit.Logger(). + WithField("stage", "upstream request"). + WithField("raw", audit.B64encode(requestBody.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream request") + } if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api @@ -168,6 +161,17 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + buf := audit.CaptureHTTPResponseBody(resp) + defer func() { + audit.Logger(). + WithField("stage", "upstream response"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream response") + }() + } err = req.Body.Close() if err != nil { @@ -220,9 +224,28 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return RelayErrorHandler(resp) } succeed = true - quotaDelta := quota - preConsumedQuota + // quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { - go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + // go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + // post-consume quota + if rl.Bookkeeper != nil { + // go postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + consumeLog := &billing.ConsumeLog{ + UserId: meta.UserId, + ChannelId: meta.ChannelId, + ModelName: audioModel, + TokenName: c.GetString(ctxkey.TokenName), + TokenId: meta.TokenId, + Quota: quota, + Content: logContent, + PromptTokens: int(preConsumeQuota), + CompletionTokens: 0, + PreConsumedQuota: preConsumedQuota, + } + rl.Bookkeeper.Consume(c, consumeLog) + } }(c.Request.Context()) for k, v := range resp.Header { diff --git a/relay/controller/controller.go b/relay/controller/controller.go new file mode 100644 index 0000000000..7b1d1479ac --- /dev/null +++ b/relay/controller/controller.go @@ -0,0 +1,30 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/billing" + "github.com/songquanpeng/one-api/relay/model" +) + +type Options struct { + EnableBilling bool +} + +// RelayInstance is the interface for relay controller +type RelayInstance interface { + RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode + RelayImageHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode + RelayAudioHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode +} + +type defaultRelay struct { + billing.Bookkeeper +} + +func NewRelayInstance(opts Options) RelayInstance { + relay := &defaultRelay{} + if opts.EnableBilling { + relay.Bookkeeper = billing.NewBookkeeper() + } + return relay +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486cb..6176efe949 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "math" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -16,9 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "math" - "net/http" - "strings" ) func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { @@ -208,10 +209,7 @@ func getMappedModelName(modelName string, mapping map[string]string) (string, bo func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { if resp == nil { - if meta.ChannelType == channeltype.AwsClaude { - return false - } - return true + return meta.ChannelType != channeltype.Azure } if resp.StatusCode != http.StatusOK { return true diff --git a/relay/controller/image.go b/relay/controller/image.go index 691c7c0e25..eab3651b6f 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -4,20 +4,23 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/audit" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func isWithinRange(element string, value int) bool { @@ -29,33 +32,40 @@ func isWithinRange(element string, value int) bool { return value >= min && value <= max } -func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { - ctx := c.Request.Context() +func (rl *defaultRelay) RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { meta := meta.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { - logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) + logger.Errorf(c, "getImageRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name - var isModelMapped bool + var ( + isModelMapped bool + preConsumeQuota int64 + preConsumedQuota int64 + imageCostRatio float64 + bizErr *relaymodel.ErrorWithStatusCode + ) meta.OriginModelName = imageRequest.Model imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) meta.ActualModelName = imageRequest.Model // model validation - bizErr := validateImageRequest(imageRequest, meta) + bizErr = validateImageRequest(imageRequest, meta) if bizErr != nil { return bizErr } - imageCostRatio, err := getImageCostRatio(imageRequest) - if err != nil { - return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) + if rl.Bookkeeper != nil { + imageCostRatio, err = getImageCostRatio(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) + } } - imageModel := imageRequest.Model + originModel := imageRequest.Model // Convert the original image model imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) c.Set("response_format", imageRequest.ResponseFormat) @@ -94,51 +104,89 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billingratio.GetModelRatio(imageModel) - groupRatio := billingratio.GetGroupRatio(meta.Group) - ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - - quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) - - if userQuota-quota < 0 { - return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + if rl.Bookkeeper != nil { + modelRatio := rl.ModelRatio(originModel) + groupRatio := rl.GroupRation(meta.Group) + ratio := modelRatio * groupRatio + preConsumeQuota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + preConsumedQuota, bizErr = rl.PreConsumeQuota(c, preConsumeQuota, meta.UserId, meta.TokenId) + if bizErr != nil { + logger.Warnf(c, "preConsumeQuota failed: %+v", *bizErr) + return bizErr + } } + refund := func() { + if rl.Bookkeeper != nil && preConsumedQuota > 0 { + rl.RefundQuota(c, preConsumedQuota, meta.TokenId) + } + } + if config.UpstreamAuditEnabled { + buf := bytes.Buffer{} + requestBody = io.TeeReader(requestBody, &buf) + defer func() { + audit.Logger(). + WithField("stage", "upstream request"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream request") + }() + } // do request resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { - logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + logger.Errorf(c, "DoRequest failed: %s", err.Error()) + refund() return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + buf := audit.CaptureHTTPResponseBody(resp) + defer func() { + audit.Logger(). + WithField("stage", "upstream response"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream response") + }() + } defer func(ctx context.Context) { if resp != nil && resp.StatusCode != http.StatusOK { return } - - err := model.PostConsumeTokenQuota(meta.TokenId, quota) - if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(ctx, meta.UserId) - if err != nil { - logger.SysError("error update user quota cache: " + err.Error()) + if rl.Bookkeeper == nil { + return } - if quota != 0 { + modelRatio := rl.ModelRatio(originModel) + groupRatio := rl.GroupRation(meta.Group) + ratio := modelRatio * groupRatio + consumedQuota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + + if consumedQuota != 0 { tokenName := c.GetString(ctxkey.TokenName) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) - channelId := c.GetInt(ctxkey.ChannelId) - model.UpdateChannelUsedQuota(channelId, quota) + consumeLog := &billing.ConsumeLog{ + UserId: meta.UserId, + ChannelId: meta.ChannelId, + ModelName: imageRequest.Model, + TokenName: tokenName, + TokenId: meta.TokenId, + Quota: consumedQuota, + Content: logContent, + PromptTokens: 0, + CompletionTokens: 0, + PreConsumedQuota: preConsumedQuota, + } + rl.Bookkeeper.Consume(c, consumeLog) } - }(c.Request.Context()) + }(c) // do response _, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + logger.Errorf(c, "respErr is not nil: %+v", respErr) return respErr } diff --git a/relay/controller/text.go b/relay/controller/text.go index 6ed19b1de8..272d6195c5 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,27 +4,30 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/audit" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/billing" - billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) -func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { - ctx := c.Request.Context() +func (rl *defaultRelay) RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { meta := meta.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { - logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error()) + logger.Errorf(c, "getAndValidateTextRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) } meta.IsStream = textRequest.Stream @@ -35,18 +38,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := billingratio.GetModelRatio(textRequest.Model) - groupRatio := billingratio.GetGroupRatio(meta.Group) - ratio := modelRatio * groupRatio - // pre-consume quota - promptTokens := getPromptTokens(textRequest, meta.Mode) - meta.PromptTokens = promptTokens - preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) - if bizErr != nil { - logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr) - return bizErr + var ( + preConsumedQuota int64 + modelRatio float64 + groupRatio float64 + ratio float64 + ) + if rl.Bookkeeper != nil { + modelRatio = rl.ModelRatio(textRequest.Model) + groupRatio = rl.GroupRation(meta.Group) + ratio = modelRatio * groupRatio + // pre-consume quota + meta.PromptTokens = getPromptTokens(textRequest, meta.Mode) + preConsumeQuota := getPreConsumedQuota(textRequest, meta.PromptTokens, ratio) + consumedQuota, bizErr := rl.PreConsumeQuota(c, preConsumeQuota, meta.UserId, meta.TokenId) + if bizErr != nil { + logger.Warnf(c, "preConsumeQuota failed: %+v", *bizErr) + return bizErr + } + preConsumedQuota = consumedQuota } - adaptor := relay.GetAdaptor(meta.APIType) if adaptor == nil { return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) @@ -76,29 +87,75 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if err != nil { return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } - logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) + logger.Debugf(c, "converted request: \n%s", string(jsonData)) requestBody = bytes.NewBuffer(jsonData) } + if config.UpstreamAuditEnabled { + buf := bytes.Buffer{} + requestBody = io.TeeReader(requestBody, &buf) + defer func() { + audit.Logger(). + WithField("stage", "upstream request"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream request") + }() + } + // do request resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { - logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + logger.Errorf(c, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + if config.UpstreamAuditEnabled { + buf := audit.CaptureHTTPResponseBody(resp) + defer func() { + audit.Logger(). + WithField("stage", "upstream response"). + WithField("raw", audit.B64encode(buf.Bytes())). + WithField("requestid", c.GetString(helper.RequestIdKey)). + WithFields(meta.ToLogrusFields()). + Info("upstream response") + }() + } + refund := func() { + if rl.Bookkeeper != nil && preConsumedQuota > 0 { + rl.RefundQuota(c, preConsumedQuota, meta.TokenId) + } + } if isErrorHappened(meta, resp) { - billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + refund() return RelayErrorHandler(resp) } // do response usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - logger.Errorf(ctx, "respErr is not nil: %+v", respErr) - billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + logger.Errorf(c, "respErr is not nil: %+v", respErr) + refund() return respErr } // post-consume quota - go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + if rl.Bookkeeper != nil { + // go postConsumeQuota(c, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + completionRatio := rl.ModelCompletionRatio(textRequest.Model) + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) + consumeLog := &billing.ConsumeLog{ + UserId: meta.UserId, + ChannelId: meta.ChannelId, + ModelName: textRequest.Model, + TokenName: c.GetString(ctxkey.TokenName), + TokenId: meta.TokenId, + Quota: usage.Quota(completionRatio, ratio), + Content: logContent, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + PreConsumedQuota: preConsumedQuota, + } + rl.Bookkeeper.Consume(c, consumeLog) + } return nil } diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 9714ebb5e6..032c2a11ee 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,12 +1,13 @@ package meta import ( + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" - "strings" ) type Meta struct { @@ -29,6 +30,29 @@ type Meta struct { PromptTokens int // only for DoResponse } +func (m *Meta) ToLogrusFields() map[string]interface{} { + return map[string]interface{}{ + "mode": m.Mode, + "channel_type": m.ChannelType, + "channel_id": m.ChannelId, + "token_id": m.TokenId, + "token_name": m.TokenName, + "user_id": m.UserId, + "group": m.Group, + "model_mapping": m.ModelMapping, + "base_url": m.BaseURL, + "api_key": m.APIKey, + "api_type": m.APIType, + "config": m.Config, + "is_stream": m.IsStream, + "origin_model_name": m.OriginModelName, + "actual_model_name": m.ActualModelName, + "request_url_path": m.RequestURLPath, + "prompt_tokens": m.PromptTokens, + } + +} + func GetByContext(c *gin.Context) *Meta { meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path), diff --git a/relay/model/misc.go b/relay/model/misc.go index 163bc398b7..5a114276c7 100644 --- a/relay/model/misc.go +++ b/relay/model/misc.go @@ -1,11 +1,21 @@ package model +import "math" + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } +func (u *Usage) Quota(completionRatio, finalRatio float64) int64 { + quota := int64(math.Ceil((float64(u.PromptTokens) + float64(u.CompletionTokens)*completionRatio) * finalRatio)) + if finalRatio != 0 && quota <= 0 { + quota = 1 + } + return quota +} + type Error struct { Message string `json:"message"` Type string `json:"type"` diff --git a/router/relay.go b/router/relay.go index 65072c869b..d4391b7865 100644 --- a/router/relay.go +++ b/router/relay.go @@ -1,6 +1,7 @@ package router import ( + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" @@ -17,19 +18,26 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1") + opt := controller.Options{ + EnableMonitor: config.EnableMetric, + EnableBilling: config.EnableBilling, + Debug: config.DebugEnabled, + } + ctrl := controller.NewRelayController(opt) + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { - relayV1Router.POST("/completions", controller.Relay) - relayV1Router.POST("/chat/completions", controller.Relay) - relayV1Router.POST("/edits", controller.Relay) - relayV1Router.POST("/images/generations", controller.Relay) + relayV1Router.POST("/completions", ctrl.Relay) + relayV1Router.POST("/chat/completions", ctrl.Relay) + relayV1Router.POST("/edits", ctrl.Relay) + relayV1Router.POST("/images/generations", ctrl.Relay) relayV1Router.POST("/images/edits", controller.RelayNotImplemented) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) - relayV1Router.POST("/embeddings", controller.Relay) - relayV1Router.POST("/engines/:model/embeddings", controller.Relay) - relayV1Router.POST("/audio/transcriptions", controller.Relay) - relayV1Router.POST("/audio/translations", controller.Relay) - relayV1Router.POST("/audio/speech", controller.Relay) + relayV1Router.POST("/embeddings", ctrl.Relay) + relayV1Router.POST("/engines/:model/embeddings", ctrl.Relay) + relayV1Router.POST("/audio/transcriptions", ctrl.Relay) + relayV1Router.POST("/audio/translations", ctrl.Relay) + relayV1Router.POST("/audio/speech", ctrl.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) @@ -41,7 +49,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) - relayV1Router.POST("/moderations", controller.Relay) + relayV1Router.POST("/moderations", ctrl.Relay) relayV1Router.POST("/assistants", controller.RelayNotImplemented) relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented) relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)