From 6eb5fb35bf9e5d0039f5f092be63ffb5fc3cdf14 Mon Sep 17 00:00:00 2001
From: "Gyarbij (Chono N)" <49493993+Gyarbij@users.noreply.github.com>
Date: Fri, 20 Dec 2024 15:15:46 +0100
Subject: [PATCH] feat: refactor proxy handling, update Vertex AI model
 mappings, and add streaming support

---
 main.go             |  18 ++--
 pkg/vertex/proxy.go | 244 +++++++++++++++++++++++++++++---------------
 2 files changed, 169 insertions(+), 93 deletions(-)
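
With this patch, "vertex" mode answers OpenAI-style chat completion requests
directly, including streaming. A minimal smoke test sketch, assuming the proxy
listens on localhost:8080 (the base URL and route are illustrative; the
handler only requires the path to end in /chat/completions):

    // Hypothetical smoke test for the new Vertex streaming path.
    package main

    import (
        "fmt"
        "io"
        "net/http"
        "os"
        "strings"
    )

    func main() {
        body := `{"model":"chat-bison","stream":true,
                  "messages":[{"role":"user","content":"Hello"}]}`
        resp, err := http.Post("http://localhost:8080/v1/chat/completions",
            "application/json", strings.NewReader(body))
        if err != nil {
            fmt.Fprintln(os.Stderr, err)
            return
        }
        defer resp.Body.Close()
        io.Copy(os.Stdout, resp.Body) // prints the SSE stream as it arrives
    }

With "stream": true the reply arrives as server-sent events built from
chat.completion.chunk objects; without it, a single chat.completion JSON
object is returned.
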
diff --git a/main.go b/main.go
index 0b32af1..2996f6c 100644
--- a/main.go
+++ b/main.go
@@ -331,31 +331,27 @@ func handleProxy(c *gin.Context) {
 		return
 	}
 
-	var server http.Handler
-
 	// Choose the proxy based on ProxyMode or specific environment variables
 	switch ProxyMode {
 	case "azure":
-		server = azure.NewOpenAIReverseProxy()
+		server := azure.NewOpenAIReverseProxy()
+		server.ServeHTTP(c.Writer, c.Request)
 	case "google":
 		google.HandleGoogleAIProxy(c)
-		return // Add this return statement
 	case "vertex":
-		server = vertex.NewVertexAIReverseProxy()
+		vertex.HandleVertexAIProxy(c) // the Vertex handler writes its own response
 	default:
 		// Default to Azure if not specified, but only if the endpoint is set
 		if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
-			server = azure.NewOpenAIReverseProxy()
+			server := azure.NewOpenAIReverseProxy()
+			server.ServeHTTP(c.Writer, c.Request)
 		} else {
 			// If no endpoint is configured, default to OpenAI
-			server = openai.NewOpenAIReverseProxy()
+			server := openai.NewOpenAIReverseProxy()
+			server.ServeHTTP(c.Writer, c.Request)
 		}
 	}
 
-	if ProxyMode != "google" {
-		server.ServeHTTP(c.Writer, c.Request)
-	}
-
 	if c.Writer.Header().Get("Content-Type") == "text/event-stream" {
 		if _, err := c.Writer.Write([]byte("\n")); err != nil {
 			log.Printf("rewrite response error: %v", err)
diff --git a/pkg/vertex/proxy.go b/pkg/vertex/proxy.go
index 12ce603..a10e279 100644
--- a/pkg/vertex/proxy.go
+++ b/pkg/vertex/proxy.go
@@ -1,15 +1,19 @@
 package vertex
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
 	"log"
 	"net/http"
-	"net/http/httputil"
-	"net/url"
-	"os/exec"
+	"os"
 	"strings"
+
+	"github.com/gin-gonic/gin"
+	"github.com/google/generative-ai-go/genai"
+	"golang.org/x/oauth2/google"
+	"google.golang.org/api/iterator"
+	"google.golang.org/api/option"
 )
 
 var (
@@ -18,10 +22,10 @@ var (
 	VertexAIAPIVersion = "v1"
 	VertexAILocation   = "us-central1"
 	VertexAIModelMapper = map[string]string{
-		"chat-bison":                   "chat-bison@001",
-		"text-bison":                   "text-bison@001",
-		"embedding-gecko":              "textembedding-gecko@001",
-		"embedding-gecko-multilingual": "textembedding-gecko-multilingual@001",
+		"chat-bison":                   "chat-bison@002",
+		"text-bison":                   "text-bison@002",
+		"embedding-gecko":              "textembedding-gecko@003",
+		"embedding-gecko-multilingual": "textembedding-gecko-multilingual@003",
 	}
 )
 
@@ -38,89 +42,166 @@ func Init(projectID string) {
 	log.Printf("Vertex AI initialized with Project ID: %s", projectID)
 }
 
-func NewVertexAIReverseProxy() *httputil.ReverseProxy {
-	config := &VertexAIConfig{
-		ProjectID:   VertexAIProjectID,
-		Endpoint:    VertexAIEndpoint,
-		APIVersion:  VertexAIAPIVersion,
-		Location:    VertexAILocation,
-		ModelMapper: VertexAIModelMapper,
+func HandleVertexAIProxy(c *gin.Context) {
+	if VertexAIProjectID == "" {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "Vertex AI Project ID not set"})
+		return
 	}
-	return newVertexAIReverseProxy(config)
-}
+	ctx := context.Background()
 
-func newVertexAIReverseProxy(config *VertexAIConfig) *httputil.ReverseProxy {
-	director := func(req *http.Request) {
-		originalURL := req.URL.String()
-		model := getModelFromRequest(req)
+	// GOOGLE_APPLICATION_CREDENTIALS must point at a service-account key file;
+	// option.WithCredentialsFile fails on an empty path.
+	creds := option.WithCredentialsFile(os.Getenv("GOOGLE_APPLICATION_CREDENTIALS"))
+	client, err := genai.NewClient(ctx, creds)
+	if err != nil {
+		log.Printf("Error creating Vertex AI client: %v", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create Vertex AI client"})
+		return
+	}
+	defer client.Close()
 
-		// Map the model name if necessary
-		if mappedModel, ok := config.ModelMapper[strings.ToLower(model)]; ok {
-			model = mappedModel
-		}
+	modelName := getModelFromRequestBody(c.Request)
+	if mappedModel, ok := VertexAIModelMapper[strings.ToLower(modelName)]; ok {
+		modelName = mappedModel
+	}
 
-		// Construct the new URL
-		targetURL := fmt.Sprintf("https://%s/%s/projects/%s/locations/%s/publishers/google/models/%s:predict", config.Endpoint, config.APIVersion, config.ProjectID, config.Location, model)
-		target, err := url.Parse(targetURL)
-		if err != nil {
-			log.Printf("Error parsing target URL: %v", err)
-			return
+	model := client.GenerativeModel(modelName)
+
+	// Only the chat completions endpoint is implemented for Vertex AI
+	if strings.HasSuffix(c.Request.URL.Path, "/chat/completions") {
+		handleChatCompletion(c, model)
+	} else {
+		c.JSON(http.StatusNotFound, gin.H{"error": "Invalid endpoint for Vertex AI"})
+	}
+}
+
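+// getModelFromRequestBody peeks at the JSON request body for a "model" field
+// and restores the body so later handlers can read it again; it returns an
+// empty string when the body is absent or not valid JSON.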
+func getModelFromRequestBody(req *http.Request) string {
+	body, _ := io.ReadAll(req.Body)
+	req.Body = io.NopCloser(strings.NewReader(string(body))) // Restore the body
+	var data map[string]interface{}
+	if err := json.Unmarshal(body, &data); err == nil {
+		if model, ok := data["model"].(string); ok {
+			return model
 		}
+	}
+	return ""
+}
+
+func handleChatCompletion(c *gin.Context, model *genai.GenerativeModel) {
+	var req struct {
+		Messages []struct {
+			Role    string `json:"role"`
+			Content string `json:"content"`
+		} `json:"messages"`
+		Stream      *bool    `json:"stream,omitempty"`
+		Temperature *float64 `json:"temperature,omitempty"`
+		TopP        *float64 `json:"top_p,omitempty"`
+		TopK        *int     `json:"top_k,omitempty"`
+	}
+
+	if err := c.BindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
+		return
+	}
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "messages must not be empty"})
+		return
+	}
+
+	cs := model.StartChat()
+	cs.History = []*genai.Content{}
 
-		// Set the target
-		req.URL.Scheme = target.Scheme
-		req.URL.Host = target.Host
-		req.URL.Path = target.Path
+	// Seed the history with every message except the last one, which is sent
+	// below; appending it here as well would submit it twice. The genai API
+	// uses the role "model" where OpenAI uses "assistant".
+	for _, msg := range req.Messages[:len(req.Messages)-1] {
+		role := msg.Role
+		if role == "assistant" {
+			role = "model"
+		}
+		cs.History = append(cs.History, &genai.Content{
+			Parts: []genai.Part{
+				genai.Text(msg.Content),
+			},
+			Role: role,
+		})
+	}
 
-		// Set Authorization header using Google Application Default Credentials (ADC)
-		token, err := getAccessToken()
+	// Set advanced parameters if provided
+	if req.Temperature != nil {
+		model.SetTemperature(float32(*req.Temperature))
+	}
+	if req.TopP != nil {
+		model.SetTopP(float32(*req.TopP))
+	}
+	if req.TopK != nil {
+		model.SetTopK(int32(*req.TopK))
+	}
+
+	// Handle streaming if requested
+	if req.Stream != nil && *req.Stream {
+		iter := cs.SendMessageStream(context.Background(), genai.Text(req.Messages[len(req.Messages)-1].Content))
+		c.Stream(func(w io.Writer) bool {
+			resp, err := iter.Next()
+			if err == iterator.Done {
+				return false
+			}
+			if err != nil {
+				log.Printf("Error generating content: %v", err)
+				c.SSEvent("error", "Failed to generate content")
+				return false
+			}
+
+			// Convert each response to OpenAI format and send as SSE
+			openaiResp := convertToOpenAIResponseStream(resp)
+			c.SSEvent("message", openaiResp)
+			return true
+		})
+	} else {
+		// Use SendMessage for a single response
+		resp, err := cs.SendMessage(context.Background(), genai.Text(req.Messages[len(req.Messages)-1].Content))
 		if err != nil {
-			log.Printf("Error getting access token: %v", err)
+			log.Printf("Error generating content: %v", err)
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate content"})
 			return
 		}
-		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 
-		log.Printf("proxying request %s -> %s", originalURL, req.URL.String())
+		// Convert the response to OpenAI format
+		openaiResp := convertToOpenAIResponse(resp)
+		c.JSON(http.StatusOK, openaiResp)
 	}
-
-	return &httputil.ReverseProxy{Director: director}
 }
 
-func getModelFromRequest(req *http.Request) string {
-	// Check the URL path for the model
-	parts := strings.Split(req.URL.Path, "/")
-	for i, part := range parts {
-		if part == "models" && i+1 < len(parts) {
-			return parts[i+1]
+// Helper function to convert a single response to OpenAI format (for streaming)
+func convertToOpenAIResponseStream(resp *genai.GenerateContentResponse) map[string]interface{} {
+	var parts []string
+	for _, candidate := range resp.Candidates {
+		for _, part := range candidate.Content.Parts {
+			parts = append(parts, fmt.Sprintf("%v", part))
 		}
 	}
 
-	// If not found in the path, try to get it from the request body
-	if req.Body != nil {
-		body, _ := io.ReadAll(req.Body)
-		req.Body = io.NopCloser(strings.NewReader(string(body))) // Restore the body
-		var data map[string]interface{}
-		if err := json.Unmarshal(body, &data); err == nil {
-			if model, ok := data["model"].(string); ok {
-				return model
-			}
-		}
+	return map[string]interface{}{
+		"object": "chat.completion.chunk",
+		"choices": []map[string]interface{}{
+			{
+				"index": 0,
+				"delta": map[string]interface{}{
+					"role":    "assistant",
+					"content": strings.Join(parts, ""),
+				},
+				// null for intermediate chunks; OpenAI sends "stop" only on
+				// the final chunk of a stream
+				"finish_reason": nil,
+			},
+		},
 	}
-
-	return ""
 }
 
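+// Note: strictly OpenAI-compatible payloads also carry "id", "created",
+// "model", and "usage" fields; clients that validate those will need them
+// added to both converters.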
-func getAccessToken() (string, error) {
-	// Use Application Default Credentials (ADC) to get an access token
-	// Ensure that your environment is set up with ADC, e.g., by running:
-	// gcloud auth application-default login
-	// Or by setting the GOOGLE_APPLICATION_CREDENTIALS environment variable
-	output, err := exec.Command("gcloud", "auth", "print-access-token").Output()
-	if err != nil {
-		return "", fmt.Errorf("failed to get access token: %v", err)
+// Helper function to convert a single response to OpenAI format (for non-streaming)
+func convertToOpenAIResponse(resp *genai.GenerateContentResponse) map[string]interface{} {
+	var choices []map[string]interface{}
+	for _, candidate := range resp.Candidates {
+		var parts []string
+		for _, part := range candidate.Content.Parts {
+			parts = append(parts, fmt.Sprintf("%v", part))
+		}
+		choices = append(choices, map[string]interface{}{
+			"index": candidate.Index,
+			"message": map[string]interface{}{
+				// OpenAI clients expect "assistant", not genai's "model"
+				"role":    "assistant",
+				"content": strings.Join(parts, ""),
+			},
+			"finish_reason": "stop",
+		})
+	}
+
+	return map[string]interface{}{
+		"object":  "chat.completion",
+		"choices": choices,
 	}
-	return strings.TrimSpace(string(output)), nil
 }
 
 type Model struct {
@@ -156,10 +237,13 @@ func FetchVertexAIModels() ([]Model, error) {
 		return nil, fmt.Errorf("Vertex AI Project ID not set")
 	}
 
-	token, err := getAccessToken()
+	ctx := context.Background()
+	// Use Application Default Credentials instead of shelling out to gcloud;
+	// this honors GOOGLE_APPLICATION_CREDENTIALS when it is set.
+	creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
 	if err != nil {
-		return nil, fmt.Errorf("failed to get access token: %v", err)
+		return nil, fmt.Errorf("failed to load Google credentials: %v", err)
 	}
+	token, err := creds.TokenSource.Token()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get access token: %v", err)
+	}
 
 	url := fmt.Sprintf("https://%s/%s/projects/%s/locations/%s/publishers/google/models", VertexAIEndpoint, VertexAIAPIVersion, VertexAIProjectID, VertexAILocation)
 	req, err := http.NewRequest("GET", url, nil)
@@ -167,13 +251,11 @@ func FetchVertexAIModels() ([]Model, error) {
 		return nil, err
 	}
 
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
-
-	client := &http.Client{}
-	resp, err := client.Do(req)
+	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
+	resp, err := http.DefaultClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
+
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
@@ -186,7 +268,6 @@ func FetchVertexAIModels() ([]Model, error) {
 			Name        string `json:"name"`
 			DisplayName string `json:"displayName"`
 			Description string `json:"description"`
-			// Add other relevant fields if needed
 		} `json:"models"`
 	}
 
@@ -196,22 +277,21 @@ func FetchVertexAIModels() ([]Model, error) {
 
 	var models []Model
 	for _, m := range vertexModels.Models {
-		// Extract model ID from the name field (e.g., "publishers/google/models/chat-bison")
 		parts := strings.Split(m.Name, "/")
 		modelID := parts[len(parts)-1]
 		models = append(models, Model{
-			ID:     modelID,
-			Object: "model",
+			ID:          modelID,
+			Object:      "model",
+			Name:        m.Name,
+			Description: m.Description,
+			// Vertex AI's model list does not expose lifecycle or status
+			// fields, so report fixed defaults
+			LifecycleStatus: "active",
+			Status:          "ready",
 			Capabilities: Capabilities{
 				Completion:     true,
 				ChatCompletion: strings.Contains(modelID, "chat"),
 				Embeddings:     strings.Contains(modelID, "embedding"),
 			},
-			LifecycleStatus: "active",
-			Status:          "ready",
-			Name:            m.Name,
-			Description:     m.Description,
 		})
 	}
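
A sketch of a caller for the reworked FetchVertexAIModels (hypothetical, for
illustration; the import path for pkg/vertex is assumed, and field names come
from the Model struct above):

    package main

    import (
        "fmt"
        "log"

        "github.com/gyarbij/azure-oai-proxy/pkg/vertex"
    )

    func main() {
        models, err := vertex.FetchVertexAIModels()
        if err != nil {
            log.Fatalf("listing Vertex AI models: %v", err)
        }
        for _, m := range models {
            fmt.Printf("%s (chat=%v, embeddings=%v)\n",
                m.ID, m.Capabilities.ChatCompletion, m.Capabilities.Embeddings)
        }
    }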