Lipsync Real3dPortrait #3201

Open · wants to merge 3 commits into master
14 changes: 14 additions & 0 deletions cmd/livepeer/starter/starter.go
@@ -1348,6 +1348,20 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_SegmentAnything2, config.ModelID, autoPrice)
}
case "lipsync":
_, ok := capabilityConstraints[core.Capability_Lipsync]
if !ok {
aiCaps = append(aiCaps, core.Capability_Lipsync)
capabilityConstraints[core.Capability_Lipsync] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}

capabilityConstraints[core.Capability_Lipsync].Models[config.ModelID] = modelConstraint

if *cfg.Network != "offchain" {
n.SetBasePriceForCap("default", core.Capability_Lipsync, config.ModelID, autoPrice)
}
}

if len(aiCaps) > 0 {
1 change: 1 addition & 0 deletions core/ai.go
@@ -24,6 +24,7 @@ type AI interface {
AudioToText(context.Context, worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LLM(context.Context, worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(context.Context, worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Lipsync(context.Context, worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error)
Warm(context.Context, string, string, worker.RunnerEndpoint, worker.OptimizationFlags) error
Stop(context.Context) error
HasCapacity(pipeline, modelID string) bool
3 changes: 3 additions & 0 deletions core/capabilities.go
@@ -79,6 +79,7 @@ const (
Capability_AudioToText Capability = 31
Capability_SegmentAnything2 Capability = 32
Capability_LLM Capability = 33
Capability_Lipsync Capability = 34
)

var CapabilityNameLookup = map[Capability]string{
@@ -117,6 +118,7 @@ var CapabilityNameLookup = map[Capability]string{
Capability_AudioToText: "Audio to text",
Capability_SegmentAnything2: "Segment anything 2",
Capability_LLM: "Large language model",
Capability_Lipsync: "Lipsync",
}

var CapabilityTestLookup = map[Capability]CapabilityTest{
@@ -208,6 +210,7 @@ func OptionalCapabilities() []Capability {
Capability_Upscale,
Capability_AudioToText,
Capability_SegmentAnything2,
Capability_Lipsync,
}
}

7 changes: 7 additions & 0 deletions core/orchestrator.go
@@ -140,6 +140,10 @@ func (orch *orchestrator) SegmentAnything2(ctx context.Context, req worker.GenSe
return orch.node.SegmentAnything2(ctx, req)
}

func (orch *orchestrator) Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
return orch.node.Lipsync(ctx, req)
}

func (orch *orchestrator) ProcessPayment(ctx context.Context, payment net.Payment, manifestID ManifestID) error {
if orch.node == nil || orch.node.Recipient == nil {
return nil
@@ -987,6 +991,9 @@ func (n *LivepeerNode) AudioToText(ctx context.Context, req worker.GenAudioToTex
func (n *LivepeerNode) SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error) {
return n.AIWorker.SegmentAnything2(ctx, req)
}
func (n *LivepeerNode) Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
return n.AIWorker.Lipsync(ctx, req)
}

func (n *LivepeerNode) imageToVideo(ctx context.Context, req worker.GenImageToVideoMultipartRequestBody) (*worker.ImageResponse, error) {
// We might support generating more than one video in the future (i.e. multiple input images/prompts)
2 changes: 2 additions & 0 deletions go.mod
@@ -238,3 +238,5 @@ require (
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
)

replace github.com/livepeer/ai-worker => ../ai-worker
2 changes: 0 additions & 2 deletions go.sum
@@ -623,8 +623,6 @@ github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4n
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/libp2p/go-openssl v0.1.0 h1:LBkKEcUv6vtZIQLVTegAil8jbNpJErQ9AnT+bWV+Ooo=
github.com/libp2p/go-openssl v0.1.0/go.mod h1:OiOxwPpL3n4xlenjx2h7AwSGaFSC/KZvf6gNdOBQMtc=
github.com/livepeer/ai-worker v0.7.0 h1:9z5Uz9WvKyQTXiurWim1ewDcVPLzz7EYZEfm2qtLAaw=
github.com/livepeer/ai-worker v0.7.0/go.mod h1:91lMzkzVuwR9kZ0EzXwf+7yVhLaNVmYAfmBtn7t3cQA=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b h1:VQcnrqtCA2UROp7q8ljkh2XA/u0KRgVv0S1xoUvOweE=
github.com/livepeer/go-tools v0.3.6-0.20240130205227-92479de8531b/go.mod h1:hwJ5DKhl+pTanFWl+EUpw1H7ukPO/H+MFpgA7jjshzw=
github.com/livepeer/joy4 v0.1.2-0.20191121080656-b2fea45cbded h1:ZQlvR5RB4nfT+cOQee+WqmaDOgGtP2oDMhcVvR4L0yA=
57 changes: 57 additions & 0 deletions server/ai_http.go
@@ -46,6 +46,7 @@ func startAIServer(lp lphttp) error {
lp.transRPC.Handle("/audio-to-text", oapiReqValidator(lp.AudioToText()))
lp.transRPC.Handle("/llm", oapiReqValidator(lp.LLM()))
lp.transRPC.Handle("/segment-anything-2", oapiReqValidator(lp.SegmentAnything2()))
lp.transRPC.Handle("/lipsync", oapiReqValidator(lp.Lipsync()))

return nil
}
@@ -205,6 +206,38 @@ func (h *lphttp) LLM() http.Handler {
})
}

func (h *lphttp) Lipsync() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
orch := h.orchestrator

remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)

multiRdr, err := r.MultipartReader()
if err != nil {
clog.Errorf(ctx, "Failed to read multipart form: %v", err)
respondWithError(w, err.Error(), http.StatusBadRequest)
return
}

var req worker.GenLipsyncMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
clog.Errorf(ctx, "Failed to bind multipart request: %v", err)
respondWithError(w, err.Error(), http.StatusInternalServerError)
return
}

if req.ModelId == nil || *req.ModelId == "" {
defaultModelId := "parler-tts/parler-tts-large-v1"
req.ModelId = &defaultModelId
} else {
clog.Infof(ctx, "model_id received: %s", *req.ModelId)
}

handleAIRequest(ctx, w, r, orch, req)
})
}

func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, orch Orchestrator, req interface{}) {
payment, err := getPayment(r.Header.Get(paymentHeader))
if err != nil {
@@ -363,6 +396,30 @@ func handleAIRequest(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}
outPixels = int64(config.Height) * int64(config.Width)
case worker.GenLipsyncMultipartRequestBody:
pipeline = "lipsync"
cap = core.Capability_Lipsync
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context) (interface{}, error) {
return orch.Lipsync(ctx, v)
}
if v.Audio != nil {
outPixels, err = common.CalculateAudioDuration(*v.Audio)
if err != nil {
respondWithError(w, "Unable to calculate duration", http.StatusBadRequest)
return
}
outPixels *= 1000 // Convert to milliseconds
} else {
if v.TextInput == nil {
respondWithError(w, "either audio or text_input is required", http.StatusBadRequest)
return
}
// TODO: extract method - this is the same as the calculation in ai_process.go
textLength := len(*v.TextInput)
// TODO (pschroedl): if TTS is staying in this branch, confirm sane pricing
lipsyncMultiplier := int64(5) // this value is based on an observation that lipsync takes ~5x more compute (VRAM/time) than TTS alone
durationSeconds := float64(textLength) / 13.0 // assuming the average speaking rate is around 13 characters per second
outPixels = int64(durationSeconds * 60) * lipsyncMultiplier
}
default:
respondWithError(w, "Unknown request type", http.StatusBadRequest)
return
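Note: the text-only branch above prices the job with a rough heuristic rather than a measured audio duration. A minimal standalone sketch of that arithmetic (the 13 characters/second speaking rate, the assumed 60 fps output, and the 5x lipsync multiplier all come from this diff; the 130-character sample prompt is hypothetical):

package main

import "fmt"

// estimateLipsyncUnits mirrors the heuristic in handleAIRequest's lipsync case:
// estimated speech duration (chars / 13 per second) times an assumed 60 fps,
// times a 5x multiplier for the extra compute lipsync needs over TTS alone.
func estimateLipsyncUnits(textLength int) int64 {
	lipsyncMultiplier := int64(5)
	durationSeconds := float64(textLength) / 13.0
	return int64(durationSeconds*60) * lipsyncMultiplier
}

func main() {
	// A hypothetical 130-character prompt: ~10s of speech, 600 frames, 3000 units.
	fmt.Println(estimateLipsyncUnits(130))
}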
55 changes: 55 additions & 0 deletions server/ai_mediaserver.go
@@ -72,6 +72,7 @@ func startAIMediaServer(ls *LivepeerServer) error {
ls.HTTPMux.Handle("/audio-to-text", oapiReqValidator(ls.AudioToText()))
ls.HTTPMux.Handle("/llm", oapiReqValidator(ls.LLM()))
ls.HTTPMux.Handle("/segment-anything-2", oapiReqValidator(ls.SegmentAnything2()))
ls.HTTPMux.Handle("/lipsync", oapiReqValidator(ls.Lipsync()))

return nil
}
@@ -516,6 +517,60 @@ func (ls *LivepeerServer) SegmentAnything2() http.Handler {
})
}

func (ls *LivepeerServer) Lipsync() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
ctx := clog.AddVal(r.Context(), clog.ClientIP, remoteAddr)
requestID := string(core.RandomManifestID())
ctx = clog.AddVal(ctx, "request_id", requestID)

multiRdr, err := r.MultipartReader()
if err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

var req worker.GenLipsyncMultipartRequestBody
if err := runtime.BindMultipart(&req, *multiRdr); err != nil {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}

clog.V(common.VERBOSE).Infof(ctx, "Received Lipsync request; image_size=%v model_id=%v", req.Image.FileSize(), req.ModelId)

params := aiRequestParams{
node: ls.LivepeerNode,
os: drivers.NodeStorage.NewSession(requestID),
sessManager: ls.AISessionManager,
}

start := time.Now()
resp, err := processLipsync(ctx, params, req)
if err != nil {
var serviceUnavailableErr *ServiceUnavailableError
var badRequestErr *BadRequestError
if errors.As(err, &serviceUnavailableErr) {
respondJsonError(ctx, w, err, http.StatusServiceUnavailable)
return
}
if errors.As(err, &badRequestErr) {
respondJsonError(ctx, w, err, http.StatusBadRequest)
return
}
respondJsonError(ctx, w, err, http.StatusInternalServerError)
return
}

took := time.Since(start)
clog.V(common.VERBOSE).Infof(ctx, "Processed Lipsync request model_id=%v took=%v", req.ModelId, took)

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
})
}


func (ls *LivepeerServer) ImageToVideoResult() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteAddr := getRemoteAddr(r)
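Note: the gateway route registered above accepts a multipart form. A minimal client sketch follows, assuming the form field names match the GenLipsyncMultipartRequestBody fields (model_id, text_input, image) and that the gateway listens on the default :8935 HTTP address; both are assumptions to verify against the ai-worker OpenAPI spec.

package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)

	// Text-driven request; an "audio" file field could be sent instead of text_input.
	_ = mw.WriteField("model_id", "parler-tts/parler-tts-large-v1")
	_ = mw.WriteField("text_input", "Hello from the lipsync pipeline")

	// Portrait image to animate.
	img, err := os.Open("portrait.png")
	if err != nil {
		panic(err)
	}
	defer img.Close()
	part, _ := mw.CreateFormFile("image", "portrait.png")
	if _, err := io.Copy(part, img); err != nil {
		panic(err)
	}
	mw.Close()

	resp, err := http.Post("http://localhost:8935/lipsync", mw.FormDataContentType(), &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, len(body), "bytes")
}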
117 changes: 117 additions & 0 deletions server/ai_process.go
@@ -33,6 +33,7 @@ const defaultUpscaleModelID = "stabilityai/stable-diffusion-x4-upscaler"
const defaultAudioToTextModelID = "openai/whisper-large-v3"
const defaultLLMModelID = "meta-llama/llama-3.1-8B-Instruct"
const defaultSegmentAnything2ModelID = "facebook/sam2-hiera-large"
const defaultLipsyncModelID = "parler-tts/parler-tts-large-v1"

type ServiceUnavailableError struct {
err error
@@ -971,6 +972,113 @@ func handleNonStreamingResponse(ctx context.Context, body io.ReadCloser, sess *A
return &res, nil
}

func CalculateLipsyncLatencyScore(took time.Duration, outFrames int64) float64 {
if outFrames <= 0 {
return 0
}

return took.Seconds() / float64(outFrames)
}

func processLipsync(ctx context.Context, params aiRequestParams, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
resp, err := processAIRequest(ctx, params, req)
if err != nil {
return nil, err
}

videoResp := resp.(*worker.VideoBinaryResponse)

return videoResp, nil
}

func submitLipsync(ctx context.Context, params aiRequestParams, sess *AISession, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error) {
var buf bytes.Buffer
mw, err := worker.NewLipsyncMultipartWriter(&buf, req)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "lipsync", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

client, err := worker.NewClientWithResponses(sess.Transcoder(), worker.WithHTTPClient(httpClient))
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "lipsync", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

outFrames := int64(0)
if req.Audio != nil {
durationSeconds, err := common.CalculateAudioDuration(*req.Audio)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "lipsync", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

clog.V(common.VERBOSE).Infof(ctx, "Submitting lipsync audio with duration: %d seconds", durationSeconds)

outFrames = int64(durationSeconds * 60) // TODO (pschroedl): validate FPS of lipsync output, confirm sane pricing

} else {
if req.TextInput == nil {
return nil, errors.New("either audio or text_input is required")
}
textLength := len(*req.TextInput)
clog.V(common.VERBOSE).Infof(ctx, "Submitting text-to-speech request with text length: %d", textLength)
// TODO (pschroedl): if TTS is staying in this branch, confirm sane pricing
lipsyncMultiplier := int64(5) // this value is based on an observation that lipsync takes ~5x more compute (VRAM/time) than TTS alone
durationSeconds := float64(textLength) / 13.0 // assuming the average speaking rate is around 13 characters per second; matches the calculation in ai_http.go
outFrames = int64(durationSeconds * 60) * lipsyncMultiplier
}

setHeaders, balUpdate, err := prepareAIPayment(ctx, sess, outFrames)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "lipsync", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}
defer completeBalanceUpdate(sess.BroadcastSession, balUpdate)

// Send the request and measure the processing time
start := time.Now()
resp, err := client.GenLipsyncWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf, setHeaders)
took := time.Since(start)
if err != nil {
if monitor.Enabled {
monitor.AIRequestError(err.Error(), "lipsync", *req.ModelId, sess.OrchestratorInfo)
}
return nil, err
}

// Check for errors in the response
if resp.JSON200 == nil {
// Handle the case where the response is not a 200 success
return nil, errors.New(strings.TrimSuffix(string(resp.Body), "\n"))
}

// Update the balance as receiving change if relevant
if balUpdate != nil {
balUpdate.Status = ReceivedChange
}

// Calculate the latency score for this lipsync request
sess.LatencyScore = CalculateLipsyncLatencyScore(took, outFrames)

// Log the AI request completion with latency score and pricing
if monitor.Enabled {
var pricePerAIUnit float64
if priceInfo := sess.OrchestratorInfo.GetPriceInfo(); priceInfo != nil && priceInfo.PixelsPerUnit != 0 {
pricePerAIUnit = float64(priceInfo.PricePerUnit) / float64(priceInfo.PixelsPerUnit)
}

monitor.AIRequestFinished(ctx, "lipsync", *req.ModelId, monitor.AIJobInfo{LatencyScore: sess.LatencyScore, PricePerUnit: pricePerAIUnit}, sess.OrchestratorInfo)
}

return resp.JSON200, nil
}

func processAIRequest(ctx context.Context, params aiRequestParams, req interface{}) (interface{}, error) {
var cap core.Capability
var modelID string
@@ -1040,6 +1148,15 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitSegmentAnything2(ctx, params, sess, v)
}
case worker.GenLipsyncMultipartRequestBody:
cap = core.Capability_Lipsync
modelID = defaultLipsyncModelID
if v.ModelId != nil {
modelID = *v.ModelId
}
submitFn = func(ctx context.Context, params aiRequestParams, sess *AISession) (interface{}, error) {
return submitLipsync(ctx, params, sess, v)
}
default:
return nil, fmt.Errorf("unsupported request type %T", req)
}
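Note: CalculateLipsyncLatencyScore above reports seconds of processing per output frame, so lower is better. A small self-contained illustration (the 10-second clip and 30-second round trip are made-up numbers; the 60 fps figure is the assumption already used for outFrames in this diff):

package main

import (
	"fmt"
	"time"
)

// Re-stated here only for illustration; the PR defines the same logic in ai_process.go.
func calculateLipsyncLatencyScore(took time.Duration, outFrames int64) float64 {
	if outFrames <= 0 {
		return 0
	}
	return took.Seconds() / float64(outFrames)
}

func main() {
	// 10 seconds of audio at the assumed 60 fps is 600 output frames; a
	// 30-second round trip scores 0.05 s/frame.
	fmt.Println(calculateLipsyncLatencyScore(30*time.Second, 600))
}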
1 change: 1 addition & 0 deletions server/rpc.go
@@ -70,6 +70,7 @@ type Orchestrator interface {
AudioToText(ctx context.Context, req worker.GenAudioToTextMultipartRequestBody) (*worker.TextResponse, error)
LLM(ctx context.Context, req worker.GenLLMFormdataRequestBody) (interface{}, error)
SegmentAnything2(ctx context.Context, req worker.GenSegmentAnything2MultipartRequestBody) (*worker.MasksResponse, error)
Lipsync(ctx context.Context, req worker.GenLipsyncMultipartRequestBody) (*worker.VideoBinaryResponse, error)
}

// Balance describes methods for a session's balance maintenance