Skip to content

Commit

Permalink
feat: 支持上传图片到钉钉平台,在图片生成流程中使用钉钉的图片 CDN 能力 (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
chzealot authored May 22, 2023
1 parent dfda88d commit 2eda9e8
Show file tree
Hide file tree
Showing 13 changed files with 434 additions and 8 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \
-e SENSITIVE_WORDS="aa,bb" \
-e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \
-e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \
-e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \
-e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/)
,觉得不错你可以来波素质三连." \
--restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
Expand Down Expand Up @@ -541,6 +542,15 @@ azure_resource_name: "xxxx"
azure_deployment_name: "xxxx"
azure_openai_token: "xxxx"

# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时的鉴权信息来识别消息来自哪个机器人应用
# 设置 credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
# 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力
credentials:
-
client_id: "put-your-client-id-here"
client_secret: "put-your-client-secret-here"

```

## 常见问题
Expand Down
7 changes: 7 additions & 0 deletions config.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ azure_resource_name: "xxxx"
azure_deployment_name: "xxxx"
azure_openai_token: "xxxx"

# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时的鉴权信息来识别消息来自哪个机器人应用
# 设置 credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
#credentials:
# -
# client_id: "put-your-client-id-here"
# client_secret: "put-your-client-secret-here"
19 changes: 19 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (
"gopkg.in/yaml.v2"
)

// Credential is one DingTalk application credential pair. It is loaded from
// the "credentials" list in the YAML configuration.
type Credential struct {
// ClientID is read from the "client_id" YAML key.
ClientID string `yaml:"client_id"`
// ClientSecret is read from the "client_secret" YAML key.
ClientSecret string `yaml:"client_secret"`
}

// Configuration 项目配置
type Configuration struct {
// 日志级别,info或者debug
Expand Down Expand Up @@ -62,6 +67,8 @@ type Configuration struct {
AzureResourceName string `yaml:"azure_resource_name"`
AzureDeploymentName string `yaml:"azure_deployment_name"`
AzureOpenAIToken string `yaml:"azure_openai_token"`
// 钉钉应用鉴权凭据
Credentials []Credential `yaml:"credentials"`
}

var config *Configuration
Expand Down Expand Up @@ -190,6 +197,18 @@ func LoadConfig() *Configuration {
if azureOpenaiToken != "" {
config.AzureOpenAIToken = azureOpenaiToken
}
// DINGTALK_CREDENTIALS carries DingTalk app credentials as comma-separated
// "client_id:client_secret" pairs. Parsed pairs are appended to whatever the
// configuration file already provided.
credentials := os.Getenv("DINGTALK_CREDENTIALS")
if credentials != "" {
	if config.Credentials == nil {
		// Keep a non-nil slice once the variable is set, even if no entry parses.
		config.Credentials = []Credential{}
	}
	for _, idSecret := range strings.Split(credentials, ",") {
		// Split on the first ':' only, so a secret may itself contain ':'.
		items := strings.SplitN(idSecret, ":", 2)
		if len(items) != 2 {
			continue
		}
		// Tolerate spaces around the separators, e.g. "id1:s1, id2:s2".
		clientID := strings.TrimSpace(items[0])
		clientSecret := strings.TrimSpace(items[1])
		if clientID == "" || clientSecret == "" {
			// An entry such as ":secret" or "id:" is unusable; skip it.
			continue
		}
		config.Credentials = append(config.Credentials, Credential{ClientID: clientID, ClientSecret: clientSecret})
	}
}

})

Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ services:
AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai"
AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai"
AZURE_OPENAI_TOKEN: "" # Azure token
DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2"
HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义
volumes:
- ./data:/app/data
Expand Down
10 changes: 9 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ func Start() {
return
}
// 先校验回调是否合法
clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
if !checkOk {
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
return
}
// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
c.Set(public.DingTalkClientIdKeyName, clientId)
// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials
if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
return
Expand Down Expand Up @@ -114,7 +122,7 @@ func Start() {
// 除去帮助之外的逻辑分流在这里处理
switch {
case strings.HasPrefix(msgObj.Text.Content, "#图片"):
err := process.ImageGenerate(&msgObj)
err := process.ImageGenerate(c, &msgObj)
if err != nil {
logger.Warning(fmt.Errorf("process request: %v", err))
return
Expand Down
20 changes: 17 additions & 3 deletions pkg/chatgpt/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package chatgpt

import (
"bytes"
"context"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
"github.com/pandodao/tokenizer-go"
"image/png"
"os"
Expand Down Expand Up @@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
return resp.Choices[0].Text, nil
}
}
func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
model := public.Config.Model
if model == openai.GPT3Dot5Turbo0301 ||
model == openai.GPT3Dot5Turbo ||
Expand Down Expand Up @@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
}

imageName := time.Now().Format("20060102-150405") + ".png"
clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
if client != nil {
mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
}

err = os.MkdirAll("data/images", 0755)
if err != nil {
return "", err
Expand All @@ -260,8 +271,11 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
if err := png.Encode(file, imgData); err != nil {
return "", err
}

return public.Config.ServiceURL + "/images/" + imageName, nil
if uploadErr == nil {
return mediaResult.MediaID, nil
} else {
return public.Config.ServiceURL + "/images/" + imageName, nil
}
}
return "", nil
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/chatgpt/export.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package chatgpt

import (
"context"
"time"

"github.com/avast/retry-go"
Expand Down Expand Up @@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error
}

// ImageQa 生成图片
func ImageQa(question, userId string) (answer string, err error) {
func ImageQa(ctx context.Context, question, userId string) (answer string, err error) {
chat := New(userId)
defer chat.Close()
// 定义一个重试策略
Expand All @@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) {
// 使用重试策略进行重试
err = retry.Do(
func() error {
answer, err = chat.GenreateImage(question)
answer, err = chat.GenreateImage(ctx, question)
if err != nil {
return err
}
Expand Down
213 changes: 213 additions & 0 deletions pkg/dingbot/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package dingbot

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/eryajf/chatgpt-dingtalk/config"
"io"
"mime/multipart"
"net/http"
url2 "net/url"
"sync"
"time"
)

// Media type values for the mediaType argument of UploadMedia, which sends
// the value as the "type" form field of the upload request.
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
const (
MediaTypeImage string = "image"
MediaTypeVoice string = "voice"
MediaTypeVideo string = "video"
MediaTypeFile string = "file"
)
// MIME type values for the mimeType argument of UploadMedia.
const (
// MimeTypeImagePng is the MIME type string for PNG images.
MimeTypeImagePng string = "image/png"
)

// MediaUploadResult is the JSON response body of the media upload API, as
// decoded by UploadMedia.
type MediaUploadResult struct {
// ErrorCode is the "errcode" field; UploadMedia treats any non-zero value as a failure.
ErrorCode int64 `json:"errcode"`
// ErrorMessage is the "errmsg" field; UploadMedia uses it as the error text on failure.
ErrorMessage string `json:"errmsg"`
// MediaID is the "media_id" field identifying the uploaded media.
MediaID string `json:"media_id"`
// CreatedAt is the "created_at" field (presumably a timestamp; unit not established here — confirm against the API doc).
CreatedAt int64 `json:"created_at"`
// Type is the "type" field of the response.
Type string `json:"type"`
}

// OAuthTokenResult is the JSON response body of the gettoken API, as decoded
// by getAccessTokenFromDingTalk.
type OAuthTokenResult struct {
// ErrorCode is the "errcode" field; any non-zero value is treated as a failure.
ErrorCode int `json:"errcode"`
// ErrorMessage is the "errmsg" field; it is used as the error text on failure.
ErrorMessage string `json:"errmsg"`
// AccessToken is the "access_token" field.
AccessToken string `json:"access_token"`
// ExpiresIn is the "expires_in" field; GetAccessToken adds it to the current
// Unix time in seconds to compute the expiry.
ExpiresIn int `json:"expires_in"`
}

// DingTalkClientInterface is the set of DingTalk OpenAPI operations exposed to
// callers; *DingTalkClient implements it.
type DingTalkClientInterface interface {
// GetAccessToken returns an access token for the client's credential.
GetAccessToken() (string, error)
// UploadMedia uploads content under the given filename and media type and
// returns the decoded upload response.
UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error)
}

// DingTalkClientManagerInterface looks up a client by its OAuth client ID;
// *DingTalkClientManager implements it.
type DingTalkClientManagerInterface interface {
// GetClientByOAuthClientID returns the client registered for clientId.
// The DingTalkClientManager implementation returns nil for an unknown ID.
GetClientByOAuthClientID(clientId string) DingTalkClientInterface
}

// DingTalkClient calls DingTalk OpenAPI on behalf of a single credential and
// caches the access token it obtains.
type DingTalkClient struct {
// Credential is the client ID / secret pair used to request access tokens.
Credential config.Credential
// AccessToken is the most recently fetched token; GetAccessToken reads and writes it under mutex.
AccessToken string
// expireAt is the cached token's expiry as Unix time in seconds; 0 means no token is cached.
expireAt int64
// mutex guards AccessToken and expireAt.
mutex sync.Mutex
}

// DingTalkClientManager holds one DingTalkClient per configured credential.
type DingTalkClientManager struct {
// Credentials is the credential list taken from the configuration.
Credentials []config.Credential
// Clients maps a credential's ClientID to its client.
Clients map[string]*DingTalkClient
// mutex is held while GetClientByOAuthClientID reads Clients.
mutex sync.Mutex
}

// NewDingTalkClient returns a client bound to the given credential. No access
// token is cached yet; one is fetched on the first GetAccessToken call.
func NewDingTalkClient(credential config.Credential) *DingTalkClient {
	client := new(DingTalkClient)
	client.Credential = credential
	return client
}

// NewDingTalkClientManager builds a manager with one client per credential in
// conf.Credentials, keyed by ClientID. If two credentials share a ClientID the
// later one wins.
//
// A nil conf is accepted and yields a manager with no credentials and an
// empty (non-nil) client map, so lookups simply return nil.
func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager {
	manager := &DingTalkClientManager{
		Clients: make(map[string]*DingTalkClient),
	}
	if conf == nil {
		// Nothing to register; do not dereference conf.
		return manager
	}

	manager.Credentials = conf.Credentials
	for _, credential := range conf.Credentials {
		manager.Clients[credential.ClientID] = NewDingTalkClient(credential)
	}
	return manager
}

// GetClientByOAuthClientID returns the client registered under clientId, or
// nil when no such client exists. The lookup is done while holding the
// manager's mutex.
func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface {
	m.mutex.Lock()
	client, found := m.Clients[clientId]
	m.mutex.Unlock()

	if !found {
		// Return an untyped nil so callers can compare the interface with nil.
		return nil
	}
	return client
}

// GetAccessToken returns an access token for this client's credential.
//
// A cached token is reused while it has more than 60 seconds of validity
// left; otherwise a new token is requested from DingTalk and cached. The
// mutex is not held during the network request, so concurrent callers that
// find the cache stale may each request a token.
func (c *DingTalkClient) GetAccessToken() (string, error) {
	// Take a snapshot of the cache under the lock.
	c.mutex.Lock()
	cached, deadline := c.AccessToken, c.expireAt
	c.mutex.Unlock()

	// Keep a one-minute margin so a token is never used right at its expiry,
	// which could otherwise produce a 401 from the API.
	if deadline > 0 && cached != "" && time.Now().Unix()+60 < deadline {
		return cached, nil
	}

	result, err := c.getAccessTokenFromDingTalk()
	if err != nil {
		return "", err
	}

	// Store the fresh token together with its absolute expiry time.
	c.mutex.Lock()
	c.AccessToken = result.AccessToken
	c.expireAt = time.Now().Unix() + int64(result.ExpiresIn)
	c.mutex.Unlock()

	return result.AccessToken, nil
}

// UploadMedia uploads content to DingTalk as a multipart/form-data request
// and returns the decoded response, whose MediaID identifies the upload.
//
// content is sent as the "media" file part named filename, and mediaType
// (one of the MediaType* constants) is sent as the "type" field. mimeType is
// accepted for interface compatibility but is not used by this
// implementation: the file part is created with multipart.Writer.CreateFormFile.
//
// An error is returned when no access token can be obtained, when building or
// sending the request fails, when the response cannot be read or decoded as
// JSON, or when the response carries a non-zero errcode (the error text is
// then the response's errmsg).
//
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) {
	accessToken, err := c.GetAccessToken()
	if err != nil {
		return nil, err
	}
	if accessToken == "" {
		return nil, errors.New("empty access token")
	}

	// Build the multipart body; every write is checked so a truncated body is
	// never sent.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("media", filename)
	if err != nil {
		return nil, err
	}
	if _, err = part.Write(content); err != nil {
		return nil, err
	}
	if err = writer.WriteField("type", mediaType); err != nil {
		return nil, err
	}
	// Close writes the terminating boundary and must happen before sending.
	if err = writer.Close(); err != nil {
		return nil, err
	}

	// Create the HTTP request that uploads the media file.
	url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken))
	req, err := http.NewRequest("POST", url, body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())

	client := &http.Client{
		Timeout: time.Second * 60,
	}
	res, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	// Read the body first and check that error before attempting to decode.
	bodyBytes, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	media := &MediaUploadResult{}
	if err = json.Unmarshal(bodyBytes, media); err != nil {
		// A non-JSON body (e.g. a gateway error page) must not pass as success.
		return nil, fmt.Errorf("decoding media upload response: %w", err)
	}
	if media.ErrorCode != 0 {
		return nil, errors.New(media.ErrorMessage)
	}
	return media, nil
}

// getAccessTokenFromDingTalk requests a new access token from the DingTalk
// gettoken endpoint, sending the credential's ClientID as "appkey" and its
// ClientSecret as "appsecret". It does not touch the client's token cache.
//
// An error is returned when the request fails, the body cannot be read or
// decoded as JSON, or the response carries a non-zero errcode (the error text
// is then the response's errmsg).
//
// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) {
	query := url2.Values{
		"appkey":    {c.Credential.ClientID},
		"appsecret": {c.Credential.ClientSecret},
	}
	endpoint := "https://oapi.dingtalk.com/gettoken" + "?" + query.Encode()

	httpClient := http.Client{Timeout: 60 * time.Second}
	// Client.Get builds a GET request with no body and sends it.
	resp, err := httpClient.Get(endpoint)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	result := new(OAuthTokenResult)
	if err := json.Unmarshal(raw, result); err != nil {
		return nil, err
	}
	if result.ErrorCode != 0 {
		return nil, errors.New(result.ErrorMessage)
	}
	return result, nil
}
Loading

0 comments on commit 2eda9e8

Please sign in to comment.