From 6317c0d314741ae70f994438fb0800a0f0adad70 Mon Sep 17 00:00:00 2001
From: winlin <winlinvip@gmail.com>
Date: Fri, 13 Sep 2024 09:00:47 +0800
Subject: [PATCH] AI: Support OpenAI o1-preview model. v5.15.22

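The o1-preview model rejects the system role, streaming responses,
MaxTokens, and a non-default temperature. Shape each chat request per
model with the new helpers in platform/openai.go, and fall back to a
synchronous chat completion when streaming is unavailable. Extract
handleSentence so the streaming and non-streaming paths commit sentences
to the TTS worker in the same way.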
---
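Note: a minimal sketch of the request shaping this patch introduces,
assuming the sashabaranov/go-openai client the repo already uses;
buildChatRequest is a hypothetical wrapper for illustration, not part of
the patch itself.

    package main

    import openai "github.com/sashabaranov/go-openai"

    // buildChatRequest shapes a chat request for models with restricted
    // parameters, such as o1-preview.
    func buildChatRequest(model string, messages []openai.ChatCompletionMessage) openai.ChatCompletionRequest {
    	return openai.ChatCompletionRequest{
    		Model:    model,
    		Messages: messages,
    		// o1-* models reject streaming, so the caller must use the
    		// synchronous CreateChatCompletion API instead.
    		Stream: gptModelSupportStream(model),
    		// Zero values are dropped by go-openai (omitempty), which is
    		// what o1-* models require for max_tokens and temperature.
    		MaxTokens:   gptModelSupportMaxTokens(model, 1024),
    		Temperature: gptModelSupportTemperature(model, 0.9),
    	}
    }
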
 .gitignore          |   1 +
 DEVELOPER.md        |   1 +
 platform/ai-talk.go | 210 ++++++++++++++++++++++++++++++--------------
 platform/openai.go  |  34 +++++++
 platform/utils.go   |   1 +
 5 files changed, 182 insertions(+), 65 deletions(-)
 create mode 100644 platform/openai.go

diff --git a/.gitignore b/.gitignore
index d0bbd55e..115237b8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,3 +140,4 @@ __pycache__
 /.*.txt
 /*.txt
 .tmp
+__debug_bin*
diff --git a/DEVELOPER.md b/DEVELOPER.md
index 3870e133..ba3015fb 100644
--- a/DEVELOPER.md
+++ b/DEVELOPER.md
@@ -1285,6 +1285,7 @@ The following are the update records for the Oryx server.
     * Dubbing: Fix bug when changing ASR segment size. v5.15.20
     * Dubbing: Refine the window of text. [v5.15.20](https://github.com/ossrs/oryx/releases/tag/v5.15.20)
     * Dubbing: Support space key to play/pause. v5.15.21
+    * AI: Support OpenAI o1-preview model. v5.15.22
 * v5.14:
     * Merge features and bugfix from releases. v5.14.1
     * Dubbing: Support VoD dubbing for multiple languages. [v5.14.2](https://github.com/ossrs/oryx/releases/tag/v5.14.2)
diff --git a/platform/ai-talk.go b/platform/ai-talk.go
index d75a5a1d..deb62dcb 100644
--- a/platform/ai-talk.go
+++ b/platform/ai-talk.go
@@ -24,6 +24,7 @@ import (
 	ohttp "github.com/ossrs/go-oryx-lib/http"
 	"github.com/ossrs/go-oryx-lib/logger"
 	"github.com/sashabaranov/go-openai"
+
 	// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
 	"github.com/go-redis/redis/v8"
 )
@@ -132,8 +133,18 @@ func (v *openaiChatService) RequestChat(ctx context.Context, sreq *StageRequest,
 
 	system := stage.prompt
 	system += fmt.Sprintf(" Keep your reply neat, limiting the reply to %v words.", stage.replyLimit)
-	messages := []openai.ChatCompletionMessage{
-		{Role: openai.ChatMessageRoleSystem, Content: system},
+	messages := []openai.ChatCompletionMessage{}
+
+	// If the model does not support system messages, use a user message instead.
+	model := stage.chatModel
+	if gptModelSupportSystem(model) {
+		messages = append(messages, openai.ChatCompletionMessage{
+			Role: openai.ChatMessageRoleSystem, Content: system,
+		})
+	} else {
+		messages = append(messages, openai.ChatCompletionMessage{
+			Role: openai.ChatMessageRoleUser, Content: system,
+		})
 	}
 
 	messages = append(messages, stage.histories...)
@@ -142,30 +153,55 @@ func (v *openaiChatService) RequestChat(ctx context.Context, sreq *StageRequest,
 		Content: user.previousAsrText,
 	})
 
-	model := stage.chatModel
 	maxTokens := 1024
 	temperature := float32(0.9)
 	logger.Tf(ctx, "AIChat is baseURL=%v, org=%v, model=%v, maxTokens=%v, temperature=%v, window=%v, histories=%v, system is %v",
 		v.conf.BaseURL, v.conf.OrgID, model, maxTokens, temperature, stage.chatWindow, len(stage.histories), system)
 
-	client := openai.NewClientWithConfig(v.conf)
-	gptChatStream, err := client.CreateChatCompletionStream(
-		ctx, openai.ChatCompletionRequest{
-			Model:       model,
-			Messages:    messages,
-			Stream:      true,
-			Temperature: temperature,
-			MaxTokens:   maxTokens,
-		},
-	)
-	if err != nil {
-		return errors.Wrapf(err, "create chat")
+	gptReq := openai.ChatCompletionRequest{
+		Model:    model,
+		Messages: messages,
+		// Some models may not support streaming.
+		Stream: gptModelSupportStream(model),
+		// Some models may not support MaxTokens.
+		MaxTokens: gptModelSupportMaxTokens(model, maxTokens),
+		// Some models may not support Temperature.
+		Temperature: gptModelSupportTemperature(model, temperature),
+	}
+
+	// For models that do not support streaming, use a synchronous chat completion.
+	if !gptModelSupportStream(model) {
+		client := openai.NewClientWithConfig(v.conf)
+		gptChat, err := client.CreateChatCompletion(ctx, gptReq)
+		if err != nil {
+			return errors.Wrapf(err, "create chat")
+		}
+
+		// For a synchronous request, complete the task when finished.
+		defer taskCancel()
+
+		if err := v.handleSentence(ctx,
+			stage, sreq, gptChat.Choices[0].Message.Content, true, nil,
+			func(sentence string) {
+				stage.previousAssitant += sentence + " "
+			},
+		); err != nil {
+			return errors.Wrapf(err, "handle chat")
+		}
+
+		return nil
 	}
 
-	// Wait for AI got the first sentence response.
+	// For OpenAI chat stream. Wait for the AI to produce the first sentence of the response.
 	aiFirstResponseCtx, aiFirstResponseCancel := context.WithCancel(ctx)
 	defer aiFirstResponseCancel()
 
+	client := openai.NewClientWithConfig(v.conf)
+	gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq)
+	if err != nil {
+		return errors.Wrapf(err, "create chat")
+	}
+
 	go func() {
 		defer gptChatStream.Close()
 		if err := v.handle(ctx,
@@ -210,8 +246,18 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR
 
 	system := stage.postPrompt
 	system += fmt.Sprintf(" Keep your reply neat, limiting the reply to %v words.", stage.postReplyLimit)
-	messages := []openai.ChatCompletionMessage{
-		{Role: openai.ChatMessageRoleSystem, Content: system},
+	messages := []openai.ChatCompletionMessage{}
+
+	// If the model does not support system messages, use a user message instead.
+	model := stage.postChatModel
+	if gptModelSupportSystem(model) {
+		messages = append(messages, openai.ChatCompletionMessage{
+			Role: openai.ChatMessageRoleSystem, Content: system,
+		})
+	} else {
+		messages = append(messages, openai.ChatCompletionMessage{
+			Role: openai.ChatMessageRoleUser, Content: system,
+		})
 	}
 
 	messages = append(messages, stage.postHistories...)
@@ -220,27 +266,49 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR
 		Content: stage.previousAssitant,
 	})
 
-	model := stage.postChatModel
 	maxTokens := 1024
 	temperature := float32(0.9)
 	logger.Tf(ctx, "AIPostProcess is baseURL=%v, org=%v, model=%v, maxTokens=%v, temperature=%v, window=%v, histories=%v, system is %v",
 		v.conf.BaseURL, v.conf.OrgID, model, maxTokens, temperature, stage.postChatWindow, len(stage.postHistories), system)
 
+	gptReq := openai.ChatCompletionRequest{
+		Model:    model,
+		Messages: messages,
+		// Some models may not support streaming.
+		Stream: gptModelSupportStream(model),
+		// Some models may not support MaxTokens.
+		MaxTokens: gptModelSupportMaxTokens(model, maxTokens),
+		// Some models may not support Temperature.
+		Temperature: gptModelSupportTemperature(model, temperature),
+	}
+
+	// For models that do not support streaming, use a synchronous chat completion.
+	if !gptModelSupportStream(model) {
+		client := openai.NewClientWithConfig(v.conf)
+		gptChat, err := client.CreateChatCompletion(ctx, gptReq)
+		if err != nil {
+			return errors.Wrapf(err, "create post-process")
+		}
+
+		if err := v.handleSentence(ctx,
+			stage, sreq, gptChat.Choices[0].Message.Content, true, nil,
+			func(sentence string) {
+				stage.postPreviousAssitant += sentence + " "
+			},
+		); err != nil {
+			return errors.Wrapf(err, "handle post-process")
+		}
+
+		return nil
+	}
+
+	// For OpenAI chat stream. Wait for the AI to produce the first sentence of the response.
 	client := openai.NewClientWithConfig(v.conf)
-	gptChatStream, err := client.CreateChatCompletionStream(
-		ctx, openai.ChatCompletionRequest{
-			Model:       model,
-			Messages:    messages,
-			Stream:      true,
-			Temperature: temperature,
-			MaxTokens:   maxTokens,
-		},
-	)
+	gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq)
 	if err != nil {
 		return errors.Wrapf(err, "create post-process")
 	}
 
-	// Wait for AI got the first sentence response.
 	aiFirstResponseCtx, aiFirstResponseCancel := context.WithCancel(ctx)
 	defer aiFirstResponseCancel()
 
@@ -268,6 +336,48 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR
 	return nil
 }
 
+func (v *openaiChatService) handleSentence(
+	ctx context.Context, stage *Stage, sreq *StageRequest,
+	sentence string, firstSentense bool,
+	aiFirstResponseCancel context.CancelFunc, onSentence func(string),
+) error {
+	// Use the sentence for the prompt history and logging.
+	if onSentence != nil && sentence != "" {
+		onSentence(sentence)
+	}
+
+	filteredSentence := sentence
+	if strings.TrimSpace(sentence) == "" {
+		return nil
+	}
+
+	if firstSentense {
+		if stage.prefix != "" {
+			filteredSentence = fmt.Sprintf("%v %v", stage.prefix, filteredSentence)
+		}
+		if v.onFirstResponse != nil {
+			v.onFirstResponse(ctx, filteredSentence)
+		}
+	}
+
+	segment := NewAnswerSegment(func(segment *AnswerSegment) {
+		segment.request = sreq
+		segment.text = filteredSentence
+		segment.first = firstSentense
+	})
+	stage.ttsWorker.SubmitSegment(ctx, stage, sreq, segment)
+
+	// We have committed the segment to the TTS worker, so we can return the response to the
+	// client and allow it to query audio segments immediately.
+	if firstSentense && aiFirstResponseCancel != nil {
+		aiFirstResponseCancel()
+	}
+
+	logger.Tf(ctx, "TTS: Commit segment rid=%v, asid=%v, first=%v, sentence is %v",
+		sreq.rid, segment.asid, firstSentense, filteredSentence)
+	return nil
+}
+
 func (v *openaiChatService) handle(
 	ctx context.Context, stage *Stage, user *StageUser, sreq *StageRequest,
 	gptChatStream *openai.ChatCompletionStream, aiFirstResponseCancel context.CancelFunc,
@@ -364,37 +474,8 @@ func (v *openaiChatService) handle(
 		return newSentence
 	}
 
-	commitAISentence := func(sentence string, firstSentense bool) {
-		filteredSentence := sentence
-		if strings.TrimSpace(sentence) == "" {
-			return
-		}
-
-		if firstSentense {
-			if stage.prefix != "" {
-				filteredSentence = fmt.Sprintf("%v %v", stage.prefix, filteredSentence)
-			}
-			if v.onFirstResponse != nil {
-				v.onFirstResponse(ctx, filteredSentence)
-			}
-		}
-
-		segment := NewAnswerSegment(func(segment *AnswerSegment) {
-			segment.request = sreq
-			segment.text = filteredSentence
-			segment.first = firstSentense
-		})
-		stage.ttsWorker.SubmitSegment(ctx, stage, sreq, segment)
-
-		// We have commit the segment to TTS worker, so we can return the response to client and allow
-		// it to query audio segments immediately.
-		if firstSentense {
-			aiFirstResponseCancel()
-		}
-
-		logger.Tf(ctx, "TTS: Commit segment rid=%v, asid=%v, first=%v, sentence is %v",
-			sreq.rid, segment.asid, firstSentense, filteredSentence)
-		return
+	commitAISentence := func(sentence string, firstSentense bool) error {
+		return v.handleSentence(ctx, stage, sreq, sentence, firstSentense, aiFirstResponseCancel, onSentence)
 	}
 
 	var sentence, lastWords string
@@ -413,12 +494,11 @@ func (v *openaiChatService) handle(
 			continue
 		}
 
-		// Use the sentence for prompt and logging.
-		if onSentence != nil && sentence != "" {
-			onSentence(sentence)
-		}
 		// Commit the sentense to TTS worker and callbacks.
-		commitAISentence(sentence, firstSentense)
+		if err = commitAISentence(sentence, firstSentense); err != nil {
+			return errors.Wrapf(err, "commit")
+		}
+
 		// Reset the sentence, because we have committed it.
 		sentence, firstSentense = "", false
 	}
diff --git a/platform/openai.go b/platform/openai.go
new file mode 100644
index 00000000..2345d7a2
--- /dev/null
+++ b/platform/openai.go
@@ -0,0 +1,34 @@
+// Copyright (c) 2022-2024 Winlin
+//
+// SPDX-License-Identifier: MIT
+package main
+
+import "strings"
+
+func gptModelSupportSystem(model string) bool {
+	if strings.HasPrefix(model, "o1-") {
+		return false
+	}
+	return true
+}
+
+func gptModelSupportStream(model string) bool {
+	if strings.HasPrefix(model, "o1-") {
+		return false
+	}
+	return true
+}
+
+func gptModelSupportMaxTokens(model string, maxTokens int) int {
+	if strings.HasPrefix(model, "o1-") {
+		return 0
+	}
+	return maxTokens
+}
+
+func gptModelSupportTemperature(model string, temperature float32) float32 {
+	if strings.HasPrefix(model, "o1-") {
+		return 0.0
+	}
+	return temperature
+}
diff --git a/platform/utils.go b/platform/utils.go
index 5a116582..c8283a7f 100644
--- a/platform/utils.go
+++ b/platform/utils.go
@@ -28,6 +28,7 @@ import (
 
 	"github.com/ossrs/go-oryx-lib/errors"
 	"github.com/ossrs/go-oryx-lib/logger"
+
 	// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
 	"github.com/go-redis/redis/v8"
 	"github.com/golang-jwt/jwt/v4"