Skip to content

Commit

Permalink
feat: automatically disable channel when error occurred (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
songquanpeng committed May 15, 2023
1 parent 44ebae1 commit a1f6138
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 19 deletions.
5 changes: 5 additions & 0 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ var TurnstileSecretKey = ""

var QuotaForNewUser = 100

var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false

var RootUserEmail = ""

const (
RoleGuestUser = 0
RoleCommonUser = 1
Expand Down
31 changes: 20 additions & 11 deletions controller/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,20 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false

// disable & notify
func disableChannel(channelId int, channelName string, err error) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error())
err = common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}

func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Lock()
if testAllChannelsRunning {
Expand All @@ -280,8 +294,10 @@ func testAllChannels(c *gin.Context) error {
return err
}
testRequest := buildTestRequest(c)
var disableThreshold int64 = 5000 // TODO: make it configurable
email := model.GetRootUserEmail()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
Expand All @@ -295,18 +311,11 @@ func testAllChannels(c *gin.Context) error {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
// disable & notify
channel.UpdateStatus(common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
err = common.SendEmail(subject, email, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
disableChannel(channel.Id, channel.Name, err)
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", email, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
Expand Down
10 changes: 10 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
Expand Down Expand Up @@ -74,6 +75,11 @@ func Relay(c *gin.Context) {
"type": "one_api_error",
},
})
if common.AutomaticDisableChannelEnabled {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err)
}
}
}

Expand Down Expand Up @@ -256,6 +262,10 @@ func relayHelper(c *gin.Context) error {
if err != nil {
return err
}
if textResponse.Error.Type != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s",
textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message))
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
Expand Down
2 changes: 2 additions & 0 deletions middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func Distribute() func(c *gin.Context) {
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
c.Set("base_url", channel.BaseURL)
Expand Down
14 changes: 7 additions & 7 deletions model/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
}
}

func (channel *Channel) UpdateStatus(status int) {
err := DB.Model(channel).Update("status", status).Error
if err != nil {
common.SysError("failed to update response time: " + err.Error())
}
}

func (channel *Channel) Delete() error {
var err error
err = DB.Delete(channel).Error
return err
}

func UpdateChannelStatusById(id int, status int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
}
6 changes: 6 additions & 0 deletions model/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func InitOptionMap() {
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
Expand Down Expand Up @@ -114,6 +116,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
}
}
switch key {
Expand Down Expand Up @@ -156,6 +160,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
}
return err
}
29 changes: 28 additions & 1 deletion web/src/components/SystemSetting.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ const SystemSetting = () => {
RegisterEnabled: '',
QuotaForNewUser: 0,
ModelRatio: '',
TopUpLink: ''
TopUpLink: '',
AutomaticDisableChannelEnabled: '',
ChannelDisableThreshold: 0,
});
let originInputs = {};
let [loading, setLoading] = useState(false);
Expand Down Expand Up @@ -62,6 +64,7 @@ const SystemSetting = () => {
case 'WeChatAuthEnabled':
case 'TurnstileCheckEnabled':
case 'RegisterEnabled':
case 'AutomaticDisableChannelEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
Expand Down Expand Up @@ -298,6 +301,30 @@ const SystemSetting = () => {
</Form.Group>
<Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
<Divider />
<Header as='h3'>
监控设置
</Header>
<Form.Group widths={3}>
<Form.Input
label='最长回应时间'
name='ChannelDisableThreshold'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行通道全部测试时,超过此时间将自动禁用通道'
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用通道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as='h3'>
配置 SMTP
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
Expand Down

0 comments on commit a1f6138

Please sign in to comment.