Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http): 为 http engine 添加直接处理响应体的能力,并优化 openai sse 处理逻辑 #40

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions network/http/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
"time"

"github.com/gin-gonic/gin"

"github.com/alioth-center/infrastructure/utils/network"

"github.com/alioth-center/infrastructure/utils/values"
Expand All @@ -20,6 +22,7 @@ type PreprocessedContext[request any, response any] interface {
SetCookieParams(params Params)
SetExtraParams(params Params)
SetRawRequest(raw *http.Request)
SetResponseWriter(res gin.ResponseWriter)
SetRequest(req request)
SetRequestHeader(headers RequestHeader)
}
Expand Down Expand Up @@ -127,6 +130,12 @@ type Context[request any, response any] interface {
// resp (response): The response to be set.
SetResponse(resp response)

// CustomRender returns the custom response writer.
//
// Returns:
// gin.ResponseWriter: The custom response writer.
CustomRender() gin.ResponseWriter

// ResponseHeaders returns the response headers.
//
// Returns:
Expand Down Expand Up @@ -202,10 +211,11 @@ type acContext[request any, response any] struct {
cookieParams Params
extraParams Params

idx int
h Chain[request, response]
raw *http.Request
ctx context.Context
idx int
h Chain[request, response]
raw *http.Request
rawRes gin.ResponseWriter
ctx context.Context

req request
resp response
Expand Down Expand Up @@ -249,6 +259,10 @@ func (c *acContext[request, response]) SetRawRequest(raw *http.Request) {
c.raw = raw
}

func (c *acContext[request, response]) SetResponseWriter(res gin.ResponseWriter) {
c.rawRes = res
}

func (c *acContext[request, response]) SetRequest(req request) {
c.req = req
}
Expand Down Expand Up @@ -341,6 +355,10 @@ func (c *acContext[request, response]) SetResponse(resp response) {
c.resp = resp
}

func (c *acContext[request, response]) CustomRender() gin.ResponseWriter {
return c.rawRes
}

func (c *acContext[request, response]) ResponseHeaders() Params {
return c.setHeaders
}
Expand Down
42 changes: 41 additions & 1 deletion network/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -147,6 +148,7 @@ type EndPoint[request any, response any] struct {
router Router
chain Chain[request, response]
allowMethods methodList
customRender bool
parsingHeaders map[string]bool
parsingQueries map[string]bool
parsingParams map[string]bool
Expand Down Expand Up @@ -300,6 +302,14 @@ func (ep *EndPoint[request, response]) Serve(ctx *gin.Context) {
}
ctx.AbortWithStatusJSON(http.StatusInternalServerError, errResponse)
return
} else if len(ctx.Errors) > 0 {
errResponse := &FrameworkResponse{
ErrorCode: ErrorCodeInternalErrorOccurred,
ErrorMessage: ctx.Errors.String(),
RequestID: tid,
}
ctx.AbortWithStatusJSON(http.StatusInternalServerError, errResponse)
return
}

// set response
Expand All @@ -309,7 +319,26 @@ func (ep *EndPoint[request, response]) Serve(ctx *gin.Context) {
for _, cookie := range context.ResponseSetCookies() {
ctx.SetCookie(cookie.Name, cookie.Value, cookie.MaxAge, cookie.Path, cookie.Domain, cookie.Secure, cookie.HttpOnly)
}
ctx.JSON(context.StatusCode(), context.Response())

// write response
if ep.customRender {
// enable no render or customRender options, write nothing to response writer
ctx.Status(context.StatusCode())
return
}

// disable custom render options, use default json render
outData, encodeErr := json.Marshal(context.Response())
if encodeErr != nil {
errResponse := &FrameworkResponse{
ErrorCode: ErrorCodeInternalErrorOccurred,
ErrorMessage: values.BuildStrings("internal error: ", encodeErr.Error()),
RequestID: tid,
}
ctx.AbortWithStatusJSON(http.StatusInternalServerError, errResponse)
return
}
ctx.Data(context.StatusCode(), ContentTypeJson, outData)
}

// EndPointOptions is the options for EndPoint.
Expand Down Expand Up @@ -453,6 +482,12 @@ func WithAllowedMethodsOpts[request any, response any](methods ...Method) EndPoi
}
}

func WithCustomRender[request any, response any](enable bool) EndPointOptions[request, response] {
return func(ep *EndPoint[request, response]) {
ep.customRender = enable
}
}

// NewEndPointWithOpts creates a new EndPoint with options.
// example:
//
Expand Down Expand Up @@ -639,6 +674,11 @@ func (eb *EndPointBuilder[request, response]) SetCustomPreprocessors(preprocesso
return eb
}

func (eb *EndPointBuilder[request, response]) SetCustomRender(enable bool) *EndPointBuilder[request, response] {
eb.options = append(eb.options, WithCustomRender[request, response](enable))
return eb
}

func (eb *EndPointBuilder[request, response]) Build() *EndPoint[request, response] {
return NewEndPointWithOpts[request, response](eb.options...)
}
Expand Down
1 change: 1 addition & 0 deletions network/http/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func NewEngine(base string) *Engine {
middlewares: []gin.HandlerFunc{},
}

e.core.Use(gin.Recovery())
e.core.Use(e.traceContext)
e.core.NoRoute(e.defaultHandler)
e.core.NoMethod(e.defaultHandler)
Expand Down
6 changes: 5 additions & 1 deletion network/http/preprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func CheckRequestCookiesPreprocessor[request any, response any](endpoint *EndPoi
dest.SetCookieParams(cookies)
}

func CheckRequestBodyPreprocessor[request any, response any](_ *EndPoint[request, response], origin *gin.Context, dest PreprocessedContext[request, response]) {
func CheckRequestBodyPreprocessor[request any, response any](ep *EndPoint[request, response], origin *gin.Context, dest PreprocessedContext[request, response]) {
// checking chain is aborted, no need to check
if origin.IsAborted() {
return
Expand Down Expand Up @@ -225,4 +225,8 @@ func CheckRequestBodyPreprocessor[request any, response any](_ *EndPoint[request
}

dest.SetRequest(requestBody)

if ep.customRender {
dest.SetResponseWriter(origin.Writer)
}
}
12 changes: 12 additions & 0 deletions network/http/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ func (p *simpleParser) BindCookie(fields ...string) (cookies map[string]*http.Co
return cookies
}

func (p *simpleParser) BindServerSentEvent(fn func(event *ServerSentEvent)) {
if p.raw == nil || p.raw.Body == nil {
return
}

for event := range ParseServerSentEventFromBody(p.raw.Body, 4096, 256) {
fn(event)
}

_ = p.raw.Body.Close()
}

func NewSimpleResponseParser(r *http.Response) ResponseParser {
// read response body
buf := &bytes.Buffer{}
Expand Down
132 changes: 132 additions & 0 deletions network/http/sse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package http

import (
"bufio"
"bytes"
"errors"
"io"
)

type ServerSentEvent struct {
ID []byte `json:"id,omitempty"`
Data []byte `json:"data,omitempty"`
Event []byte `json:"event,omitempty"`
Retry []byte `json:"retry,omitempty"`
Comment []byte `json:"comment,omitempty"`
Error error `json:"-"`
}

func ParseServerSentEventFromBody(body io.Reader, readBufferSize, receiverBufferSize int) <-chan *ServerSentEvent {
ch := make(chan *ServerSentEvent, receiverBufferSize)
dec := &decoder{events: ch}
return dec.decode(body, readBufferSize)
}

// decoder is a custom Server-Sent Events from gin-contrib/sse/decoder. see [gin-contrib/sse]
//
// [gin-contrib/sse]: https://github.com/gin-contrib/sse/blob/master/sse-decoder.go
type decoder struct {
events chan *ServerSentEvent
}

func (d *decoder) dispatchEvent(event ServerSentEvent, data []byte) {
dataLength := len(data)
if dataLength > 0 {
// If the data buffer's last character is a U+000A LINE FEED (LF) character, then remove the last character from the data buffer.
data = data[:dataLength-1]
dataLength--
}
if dataLength == 0 && len(event.Event) == 0 {
return
}
if len(event.Event) == 0 {
event.Event = []byte("message")
}
event.Data = data

d.events <- &event
}

func (d *decoder) decode(r io.Reader, readBufferSize int) <-chan *ServerSentEvent {
go func() {
reader := bufio.NewReaderSize(r, readBufferSize)
for {
line, readErr := reader.ReadBytes('\n')
if readErr != nil && !errors.Is(readErr, io.EOF) {
currentEvent := ServerSentEvent{Error: readErr}
d.events <- &currentEvent
break
}

if readErr != nil && errors.Is(readErr, io.EOF) {
close(d.events)
return
}

currentEvent := ServerSentEvent{}
dataBuffer := &bytes.Buffer{}
if len(line) == 0 {
// If the line is empty (a blank line). Dispatch the event.
d.dispatchEvent(currentEvent, dataBuffer.Bytes())

// reset current event and data buffer
currentEvent = ServerSentEvent{}
dataBuffer.Reset()
continue
}
if line[0] == byte(':') {
// If the line starts with a U+003A COLON character (:), ignore the line.
continue
}

var field, value []byte
colonIndex := bytes.IndexRune(line, ':')
if colonIndex != -1 {
// If the line contains a U+003A COLON character character (:)
// Collect the characters on the line before the first U+003A COLON character (:),
// and let field be that string.
field = line[:colonIndex]
// Collect the characters on the line after the first U+003A COLON character (:),
// and let value be that string.
value = line[colonIndex+1:]
// If value starts with a single U+0020 SPACE character, remove it from value.
if len(value) > 0 && value[0] == ' ' {
value = value[1:]
}
} else {
// Otherwise, the string is not empty but does not contain a U+003A COLON character character (:)
// Use the whole line as the field name, and the empty string as the field value.
field = line
value = []byte{}
}
// The steps to process the field given a field name and a field value depend on the field name,
// as given in the following list. Field names must be compared literally,
// with no case folding performed.
switch string(field) {
case "event":
// Set the event name buffer to field value.
currentEvent.Event = value
case "id":
// Set the event stream's last event ID to the field value.
currentEvent.ID = value
case "retry":
// If the field value consists of only characters in the range U+0030 DIGIT ZERO (0) to U+0039 DIGIT NINE (9),
// then interpret the field value as an integer in base ten, and set the event stream's reconnection time to that integer.
// Otherwise, ignore the field.
currentEvent.ID = value
case "data":
// Append the field value to the data buffer,
dataBuffer.Write(value)
// then append a single U+000A LINE FEED (LF) character to the data buffer.
dataBuffer.WriteString("\n")
default:
// Otherwise. The field is ignored.
continue
}

d.dispatchEvent(currentEvent, dataBuffer.Bytes())
}
}()

return d.events
}
4 changes: 2 additions & 2 deletions network/http/unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,8 +1042,8 @@ func TestEngine(t *testing.T) {

// Test registerEndpoints
engine.registerEndpoints()
if len(engine.core.Handlers) != 3 { // 1 for traceContext, 2 for added middlewares
t.Fatalf("Expected 3 handlers, got %d", len(engine.core.Handlers))
if len(engine.core.Handlers) != 4 { // 1 for traceContext, 2 for added middlewares
t.Fatalf("Expected 4 handlers, got %d", len(engine.core.Handlers))
}

// Test traceContext
Expand Down
18 changes: 4 additions & 14 deletions thirdparty/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/alioth-center/infrastructure/logger"
"github.com/alioth-center/infrastructure/network/http"
"github.com/alioth-center/infrastructure/utils/values"
"github.com/gin-contrib/sse"
"github.com/pandodao/tokenizer-go"
)

Expand Down Expand Up @@ -167,25 +166,16 @@ func (c client) CompleteStreamingChat(ctx context.Context, req CompleteChatReque
go func(events chan StreamingReplyObject, body io.ReadCloser) {
defer close(events)

decoded, decodeErr := sse.Decode(body)
if decodeErr != nil {
c.logger.Error(logger.NewFields(ctx).WithMessage("decode complete chat response error").WithData(decodeErr))
return
}
for _, event := range decoded {
for event := range http.ParseServerSentEventFromBody(body, 4096, 256) {
reply := StreamingReplyObject{}
payload, ok := event.Data.(string)
if !ok {
c.logger.Error(logger.NewFields(ctx).WithMessage("convert complete chat response data error").WithData(map[string]any{"event": event}))
continue
}
payload := event.Data

// end of the conversation
if payload == "[DONE]" {
if string(payload) == "[DONE]" {
break
}

if unmarshalErr := json.Unmarshal(json.RawMessage(payload), &reply); unmarshalErr != nil {
if unmarshalErr := json.Unmarshal(payload, &reply); unmarshalErr != nil {
c.logger.Error(logger.NewFields(ctx).WithMessage("unmarshal complete chat response error").WithData(map[string]any{"error": unmarshalErr, "event": event}))
continue
}
Expand Down