From 439057979b22c6e6c4d6d388c2a811ee13fbcda3 Mon Sep 17 00:00:00 2001 From: Max Holland Date: Wed, 16 Oct 2024 19:29:41 +0100 Subject: [PATCH] Refactor to remove the repetition (#3203) * Refactor to remove the repetition * Fix debug logging --- server/ai_mediaserver.go | 250 ++++----------------------------------- server/ai_process.go | 8 ++ 2 files changed, 31 insertions(+), 227 deletions(-) diff --git a/server/ai_mediaserver.go b/server/ai_mediaserver.go index d8bca64a5..128f9aade 100644 --- a/server/ai_mediaserver.go +++ b/server/ai_mediaserver.go @@ -64,94 +64,52 @@ func startAIMediaServer(ls *LivepeerServer) error { openapi3filter.RegisterBodyDecoder("image/png", openapi3filter.FileBodyDecoder) - ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(ls.TextToImage())) - ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(ls.ImageToImage())) - ls.HTTPMux.Handle("/upscale", oapiReqValidator(ls.Upscale())) + ls.HTTPMux.Handle("/text-to-image", oapiReqValidator(handle(ls, jsonDecoder[worker.GenTextToImageJSONRequestBody], processTextToImage))) + ls.HTTPMux.Handle("/image-to-image", oapiReqValidator(handle(ls, multipartDecoder[worker.GenImageToImageMultipartRequestBody], processImageToImage))) + ls.HTTPMux.Handle("/upscale", oapiReqValidator(handle(ls, multipartDecoder[worker.GenUpscaleMultipartRequestBody], processUpscale))) ls.HTTPMux.Handle("/image-to-video", oapiReqValidator(ls.ImageToVideo())) ls.HTTPMux.Handle("/image-to-video/result", ls.ImageToVideoResult()) - ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText())) + ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(handle(ls, multipartDecoder[worker.GenAudioToTextMultipartRequestBody], processAudioToText))) ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM())) - ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2())) + ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(handle(ls, multipartDecoder[worker.GenSegmentAnything2MultipartRequestBody], processSegmentAnything2))) return nil } -func (ls *LivepeerServer) TextToImage() http.Handler { +// Decoder for JSON requests +func jsonDecoder[T any](req *T, r *http.Request) error { + return json.NewDecoder(r.Body).Decode(req) +} + +// Decoder for Multipart requests +func multipartDecoder[T any](req *T, r *http.Request) error { + multiRdr, err := r.MultipartReader() + if err != nil { + return err + } + return runtime.BindMultipart(req, *multiRdr) +} + +func handle[I, O any](ls *LivepeerServer, decoderFunc func(*I, *http.Request) error, processorFunc func(context.Context, aiRequestParams, I) (O, error)) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) requestID := string(core.RandomManifestID()) ctx = clog.AddVal(ctx, "request_id", requestID) - var req worker.GenTextToImageJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - clog.V(common.VERBOSE).Infof(ctx, "Received TextToImage request prompt=%v model_id=%v", req.Prompt, *req.ModelId) - params := aiRequestParams{ node: ls.LivepeerNode, os: drivers.NodeStorage.NewSession(requestID), sessManager: ls.AISessionManager, } - start := time.Now() - resp, err := processTextToImage(ctx, params, req) - if err != nil { - var serviceUnavailableErr *ServiceUnavailableError - var badRequestErr *BadRequestError - if errors.As(err, &serviceUnavailableErr) { - respondJsonError(ctx, w, err, http.StatusServiceUnavailable) - return - } - if errors.As(err, &badRequestErr) { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - respondJsonError(ctx, w, err, http.StatusInternalServerError) - return - } - - took := time.Since(start) - clog.Infof(ctx, "Processed TextToImage request prompt=%v model_id=%v took=%v", req.Prompt, *req.ModelId, took) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(resp) - }) -} - -func (ls *LivepeerServer) ImageToImage() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - remoteAddr := getRemoteAddr(r) - ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) - requestID := string(core.RandomManifestID()) - ctx = clog.AddVal(ctx, "request_id", requestID) - - multiRdr, err := r.MultipartReader() - if err != nil { + var req I + if err := decoderFunc(&req, r); err != nil { respondJsonError(ctx, w, err, http.StatusBadRequest) return } - var req worker.GenImageToImageMultipartRequestBody - if err := runtime.BindMultipart(&req, *multiRdr); err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - clog.V(common.VERBOSE).Infof(ctx, "Received ImageToImage request imageSize=%v prompt=%v model_id=%v", req.Image.FileSize(), req.Prompt, *req.ModelId) - - params := aiRequestParams{ - node: ls.LivepeerNode, - os: drivers.NodeStorage.NewSession(requestID), - sessManager: ls.AISessionManager, - } - - start := time.Now() - resp, err := processImageToImage(ctx, params, req) + resp, err := processorFunc(ctx, params, req) if err != nil { var serviceUnavailableErr *ServiceUnavailableError var badRequestErr *BadRequestError @@ -167,9 +125,6 @@ func (ls *LivepeerServer) ImageToImage() http.Handler { return } - took := time.Since(start) - clog.V(common.VERBOSE).Infof(ctx, "Processed ImageToImage request imageSize=%v prompt=%v model_id=%v took=%v", req.Image.FileSize(), req.Prompt, *req.ModelId, took) - w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(resp) @@ -290,112 +245,6 @@ func (ls *LivepeerServer) ImageToVideo() http.Handler { }) } -func (ls *LivepeerServer) Upscale() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - remoteAddr := getRemoteAddr(r) - ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) - requestID := string(core.RandomManifestID()) - ctx = clog.AddVal(ctx, "request_id", requestID) - - multiRdr, err := r.MultipartReader() - if err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - var req worker.GenUpscaleMultipartRequestBody - if err := runtime.BindMultipart(&req, *multiRdr); err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - clog.V(common.VERBOSE).Infof(ctx, "Received Upscale request imageSize=%v prompt=%v model_id=%v", req.Image.FileSize(), req.Prompt, *req.ModelId) - - params := aiRequestParams{ - node: ls.LivepeerNode, - os: drivers.NodeStorage.NewSession(requestID), - sessManager: ls.AISessionManager, - } - - start := time.Now() - resp, err := processUpscale(ctx, params, req) - if err != nil { - var serviceUnavailableErr *ServiceUnavailableError - var badRequestErr *BadRequestError - if errors.As(err, &serviceUnavailableErr) { - respondJsonError(ctx, w, err, http.StatusServiceUnavailable) - return - } - if errors.As(err, &badRequestErr) { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - respondJsonError(ctx, w, err, http.StatusInternalServerError) - return - } - - took := time.Since(start) - clog.V(common.VERBOSE).Infof(ctx, "Processed Upscale request imageSize=%v prompt=%v model_id=%v took=%v", req.Image.FileSize(), req.Prompt, *req.ModelId, took) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(resp) - }) -} - -func (ls *LivepeerServer) AudioToText() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - remoteAddr := getRemoteAddr(r) - ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) - requestID := string(core.RandomManifestID()) - ctx = clog.AddVal(ctx, "request_id", requestID) - - multiRdr, err := r.MultipartReader() - if err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - var req worker.GenAudioToTextMultipartRequestBody - if err := runtime.BindMultipart(&req, *multiRdr); err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - clog.V(common.VERBOSE).Infof(ctx, "Received AudioToText request audioSize=%v model_id=%v", req.Audio.FileSize(), *req.ModelId) - - params := aiRequestParams{ - node: ls.LivepeerNode, - os: drivers.NodeStorage.NewSession(requestID), - sessManager: ls.AISessionManager, - } - - start := time.Now() - resp, err := processAudioToText(ctx, params, req) - if err != nil { - var serviceUnavailableErr *ServiceUnavailableError - var badRequestErr *BadRequestError - if errors.As(err, &serviceUnavailableErr) { - respondJsonError(ctx, w, err, http.StatusServiceUnavailable) - return - } - if errors.As(err, &badRequestErr) { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - respondJsonError(ctx, w, err, http.StatusInternalServerError) - return - } - - took := time.Since(start) - clog.V(common.VERBOSE).Infof(ctx, "Processed AudioToText request model_id=%v took=%v", *req.ModelId, took) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(resp) - }) -} - func (ls *LivepeerServer) LLM() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) @@ -463,59 +312,6 @@ func (ls *LivepeerServer) LLM() http.Handler { }) } -func (ls *LivepeerServer) SegmentAnything2() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - remoteAddr := getRemoteAddr(r) - ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr) - requestID := string(core.RandomManifestID()) - ctx = clog.AddVal(ctx, "request_id", requestID) - - multiRdr, err := r.MultipartReader() - if err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - var req worker.GenSegmentAnything2MultipartRequestBody - if err := runtime.BindMultipart(&req, *multiRdr); err != nil { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - - clog.V(common.VERBOSE).Infof(ctx, "Received SegmentAnything2 request; image_size=%v model_id=%v", req.Image.FileSize(), *req.ModelId) - - params := aiRequestParams{ - node: ls.LivepeerNode, - os: drivers.NodeStorage.NewSession(requestID), - sessManager: ls.AISessionManager, - } - - start := time.Now() - resp, err := processSegmentAnything2(ctx, params, req) - if err != nil { - var serviceUnavailableErr *ServiceUnavailableError - var badRequestErr *BadRequestError - if errors.As(err, &serviceUnavailableErr) { - respondJsonError(ctx, w, err, http.StatusServiceUnavailable) - return - } - if errors.As(err, &badRequestErr) { - respondJsonError(ctx, w, err, http.StatusBadRequest) - return - } - respondJsonError(ctx, w, err, http.StatusInternalServerError) - return - } - - took := time.Since(start) - clog.V(common.VERBOSE).Infof(ctx, "Processed SegmentAnything2 request model_id=%v took=%v", *req.ModelId, took) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(resp) - }) -} - func (ls *LivepeerServer) ImageToVideoResult() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { remoteAddr := getRemoteAddr(r) diff --git a/server/ai_process.go b/server/ai_process.go index f39e321a7..e088c24c9 100644 --- a/server/ai_process.go +++ b/server/ai_process.go @@ -986,6 +986,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitTextToImage(ctx, params, sess, v) } + ctx = clog.AddVal(ctx, "prompt", v.Prompt) case worker.GenImageToImageMultipartRequestBody: cap = core.Capability_ImageToImage modelID = defaultImageToImageModelID @@ -995,6 +996,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitImageToImage(ctx, params, sess, v) } + ctx = clog.AddVal(ctx, "prompt", v.Prompt) case worker.GenImageToVideoMultipartRequestBody: cap = core.Capability_ImageToVideo modelID = defaultImageToVideoModelID @@ -1013,6 +1015,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitUpscale(ctx, params, sess, v) } + ctx = clog.AddVal(ctx, "prompt", v.Prompt) case worker.GenAudioToTextMultipartRequestBody: cap = core.Capability_AudioToText modelID = defaultAudioToTextModelID @@ -1031,6 +1034,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) { return submitLLM(ctx, params, sess, v) } + ctx = clog.AddVal(ctx, "prompt", v.Prompt) case worker.GenSegmentAnything2MultipartRequestBody: cap = core.Capability_SegmentAnything2 modelID = defaultSegmentAnything2ModelID @@ -1046,6 +1050,10 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface capName := cap.String() ctx = clog.AddVal(ctx, "capability", capName) + clog.V(common.VERBOSE).Infof(ctx, "Received AI request model_id=%s", modelID) + start := time.Now() + defer clog.Infof(ctx, "Processed AI request model_id=%v took=%v", modelID, time.Since(start)) + var resp interface{} cctx, cancel := context.WithTimeout(ctx, processingRetryTimeout)