diff --git a/pkg/config/pipeline.go b/pkg/config/pipeline.go index 31585135..5dcdc82d 100644 --- a/pkg/config/pipeline.go +++ b/pkg/config/pipeline.go @@ -98,7 +98,6 @@ type TrackSource struct { MimeType types.MimeType PayloadType webrtc.PayloadType ClockRate uint32 - EOSFunc func() } type AudioConfig struct { diff --git a/pkg/gstreamer/bin.go b/pkg/gstreamer/bin.go index 38d08a9a..abfd0ab0 100644 --- a/pkg/gstreamer/bin.go +++ b/pkg/gstreamer/bin.go @@ -17,6 +17,7 @@ package gstreamer import ( "fmt" "sync" + "time" "github.com/tinyzimmer/go-gst/gst" @@ -27,6 +28,7 @@ import ( // Bins are designed to hold a single stream, with any number of sources and sinks type Bin struct { *Callbacks + *StateManager pipeline *gst.Pipeline mu sync.Mutex @@ -34,7 +36,7 @@ type Bin struct { latency uint64 linkFunc func() error - eosFunc func() + eosFunc func() bool getSrcPad func(string) *gst.Pad getSinkPad func(string) *gst.Pad @@ -48,80 +50,74 @@ type Bin struct { func (b *Bin) NewBin(name string) *Bin { return &Bin{ - Callbacks: b.Callbacks, - pipeline: b.pipeline, - bin: gst.NewBin(name), - pads: make(map[string]*gst.GhostPad), + Callbacks: b.Callbacks, + StateManager: b.StateManager, + pipeline: b.pipeline, + bin: gst.NewBin(name), + pads: make(map[string]*gst.GhostPad), } } // Add src as a source of b. This should only be called once for each source bin func (b *Bin) AddSourceBin(src *Bin) error { - b.mu.Lock() - defer b.mu.Unlock() - - src.mu.Lock() - alreadyAdded := src.added - src.added = true - src.mu.Unlock() - if alreadyAdded { - return errors.ErrBinAlreadyAdded - } - - b.srcs = append(b.srcs, src) - if err := b.pipeline.Add(src.bin.Element); err != nil { - return errors.ErrGstPipelineError(err) - } + logger.Debugw(fmt.Sprintf("adding src %s to %s", src.bin.GetName(), b.bin.GetName())) + return b.addBin(src, gst.PadDirectionSource) +} - if b.bin.GetState() == gst.StatePlaying { - if err := src.link(); err != nil { - return err - } +// Add src as a sink of b. This should only be called once for each sink bin +func (b *Bin) AddSinkBin(sink *Bin) error { + logger.Debugw(fmt.Sprintf("adding sink %s to %s", sink.bin.GetName(), b.bin.GetName())) + return b.addBin(sink, gst.PadDirectionSink) +} - src.mu.Lock() - err := linkPeersLocked(src, b) - src.mu.Unlock() - if err != nil { - return err - } +func (b *Bin) addBin(bin *Bin, direction gst.PadDirection) error { + b.LockStateShared() + defer b.UnlockStateShared() - if err = src.bin.SetState(gst.StatePlaying); err != nil { - return err - } + state := b.GetStateLocked() + if state > StateRunning { + return nil } - return nil -} - -// Add src as a sink of b. This should only be called once for each sink bin -func (b *Bin) AddSinkBin(sink *Bin) error { b.mu.Lock() defer b.mu.Unlock() - sink.mu.Lock() - alreadyAdded := sink.added - sink.added = true - sink.mu.Unlock() + bin.mu.Lock() + alreadyAdded := bin.added + bin.added = true + bin.mu.Unlock() if alreadyAdded { return errors.ErrBinAlreadyAdded } - b.sinks = append(b.sinks, sink) - if err := b.pipeline.Add(sink.bin.Element); err != nil { + if direction == gst.PadDirectionSource { + b.srcs = append(b.srcs, bin) + } else { + b.sinks = append(b.sinks, bin) + } + + if err := b.pipeline.Add(bin.bin.Element); err != nil { return errors.ErrGstPipelineError(err) } - if b.bin.GetState() == gst.StatePlaying { - if err := sink.link(); err != nil { - return err - } + if state == StateBuilding { + return nil + } - sink.mu.Lock() - err := linkPeersLocked(b, sink) - sink.mu.Unlock() - if err != nil { - return err - } + if err := bin.link(); err != nil { + return err + } + + var err error + bin.mu.Lock() + if direction == gst.PadDirectionSource { + err = linkPeersLocked(bin, b) + } else { + err = linkPeersLocked(b, bin) + } + bin.mu.Unlock() + if err != nil { + return err } return nil @@ -153,107 +149,128 @@ func (b *Bin) AddElements(elements ...*gst.Element) error { } func (b *Bin) RemoveSourceBin(name string) (bool, error) { + return b.removeBin(name, gst.PadDirectionSource) +} + +func (b *Bin) RemoveSinkBin(name string) (bool, error) { + return b.removeBin(name, gst.PadDirectionSink) +} + +func (b *Bin) removeBin(name string, direction gst.PadDirection) (bool, error) { + b.LockStateShared() + defer b.UnlockStateShared() + b.mu.Lock() defer b.mu.Unlock() - var src *Bin - for i, s := range b.srcs { - if s.bin.GetName() == name { - src = s - b.srcs = append(b.srcs[:i], b.srcs[i+1:]...) - break + var bin *Bin + if direction == gst.PadDirectionSource { + for i, s := range b.srcs { + if s.bin.GetName() == name { + bin = s + b.srcs = append(b.srcs[:i], b.srcs[i+1:]...) + break + } } - removed, err := s.RemoveSourceBin(name) - if removed || err != nil { - return removed, err + } else { + for i, s := range b.sinks { + if s.bin.GetName() == name { + bin = s + b.sinks = append(b.sinks[:i], b.sinks[i+1:]...) + break + } } } - if src == nil { + if bin == nil { return false, nil } - if b.bin.GetState() != gst.StatePlaying { - if err := b.pipeline.Remove(src.bin.Element); err != nil { + state := b.GetStateLocked() + if state > StateRunning { + return true, nil + } + + if state == StateBuilding { + if err := b.pipeline.Remove(bin.bin.Element); err != nil { return false, errors.ErrGstPipelineError(err) } return true, nil } - if err := src.bin.SetState(gst.StateNull); err != nil { - return false, err + if direction == gst.PadDirectionSource { + b.probeRemoveSource(bin) + } else { + b.probeRemoveSink(bin) } + return true, nil +} + +func (b *Bin) probeRemoveSource(src *Bin) { src.mu.Lock() - srcPad, sinkPad := getGhostPads(src, b) + srcGhostPad, sinkGhostPad := getGhostPads(src, b) src.mu.Unlock() - srcPad.Unlink(sinkPad.Pad) - if err := b.pipeline.Remove(src.bin.Element); err != nil { - return false, errors.ErrGstPipelineError(err) - } - - b.bin.RemovePad(sinkPad.Pad) - b.elements[0].ReleaseRequestPad(sinkPad.GetTarget()) - return true, nil -} + srcGhostPad.AddProbe(gst.PadProbeTypeIdle, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn { + sinkPad := sinkGhostPad.GetTarget() + b.elements[0].ReleaseRequestPad(sinkPad) -func (b *Bin) RemoveSinkBin(name string) (bool, error) { - b.mu.Lock() - defer b.mu.Unlock() + srcGhostPad.Unlink(sinkGhostPad.Pad) + b.bin.RemovePad(sinkGhostPad.Pad) - var sink *Bin - for i, s := range b.sinks { - if s.bin.GetName() == name { - sink = s - b.sinks = append(b.sinks[:i], b.sinks[i+1:]...) - break - } - removed, err := s.RemoveSinkBin(name) - if removed || err != nil { - return removed, err + if err := b.pipeline.Remove(src.bin.Element); err != nil { + b.OnError(err) } - } - if sink == nil { - return false, nil - } - if b.bin.GetState() != gst.StatePlaying { - if err := b.pipeline.Remove(sink.bin.Element); err != nil { - return false, errors.ErrGstPipelineError(err) + if err := src.bin.SetState(gst.StateNull); err != nil { + logger.Warnw(fmt.Sprintf("failed to change %s state", src.bin.GetName()), err) } - return true, nil - } + return gst.PadProbeRemove + }) +} +func (b *Bin) probeRemoveSink(sink *Bin) { sink.mu.Lock() - srcPad, sinkPad := getGhostPads(b, sink) + srcGhostPad, sinkGhostPad := getGhostPads(b, sink) sink.mu.Unlock() - srcPad.AddProbe(gst.PadProbeTypeBlockDownstream, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn { - srcPad.Unlink(sinkPad.Pad) - sinkPad.Pad.SendEvent(gst.NewEOSEvent()) + srcGhostPad.AddProbe(gst.PadProbeTypeBlockDownstream, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn { + srcGhostPad.Unlink(sinkGhostPad.Pad) + sinkGhostPad.Pad.SendEvent(gst.NewEOSEvent()) b.mu.Lock() err := b.pipeline.Remove(sink.bin.Element) b.mu.Unlock() + if err != nil { b.OnError(errors.ErrGstPipelineError(err)) return gst.PadProbeRemove } - if err = sink.bin.SetState(gst.StateNull); err != nil { + if err = sink.SetState(gst.StateNull); err != nil { logger.Warnw(fmt.Sprintf("failed to change %s state", sink.bin.GetName()), err) } - b.elements[len(b.elements)-1].ReleaseRequestPad(srcPad.GetTarget()) - b.bin.RemovePad(srcPad.Pad) + b.elements[len(b.elements)-1].ReleaseRequestPad(srcGhostPad.GetTarget()) + b.bin.RemovePad(srcGhostPad.Pad) return gst.PadProbeRemove }) - - return true, nil } func (b *Bin) SetState(state gst.State) error { - return b.bin.SetState(state) + stateErr := make(chan error, 1) + go func() { + stateErr <- b.bin.SetState(state) + }() + select { + case <-time.After(stateChangeTimeout): + return errors.ErrPipelineFrozen + case err := <-stateErr: + if err != nil { + return errors.ErrGstPipelineError(err) + } + } + return nil } // Set a custom linking function for this bin's elements (used when you need to modify chain functions) @@ -280,14 +297,39 @@ func (b *Bin) SetGetSinkPad(f func(sinkName string) *gst.Pad) { b.getSinkPad = f } -// Set a custom EOS function (used for appsrc) -func (b *Bin) SetEOSFunc(f func()) { +// Set a custom EOS function (used for appsrc, input-selector). If it returns true, EOS will also be sent to src bins +func (b *Bin) SetEOSFunc(f func() bool) { b.mu.Lock() defer b.mu.Unlock() b.eosFunc = f } +func (b *Bin) sendEOS() { + b.mu.Lock() + eosFunc := b.eosFunc + srcs := b.srcs + b.mu.Unlock() + + if eosFunc != nil && !eosFunc() { + return + } + + if len(srcs) > 0 { + var wg sync.WaitGroup + wg.Add(len(b.srcs)) + for _, src := range srcs { + go func(s *Bin) { + s.sendEOS() + wg.Done() + }(src) + } + wg.Wait() + } else if len(b.elements) > 0 { + b.bin.SendEvent(gst.NewEOSEvent()) + } +} + // ----- Internal ----- func (b *Bin) link() error { @@ -364,20 +406,35 @@ func (b *Bin) link() error { } func linkPeersLocked(src, sink *Bin) error { - srcPad, sinkPad, err := createGhostPads(src, sink) + srcPad, sinkPad, err := createGhostPads(src, sink, nil) if err != nil { return err } - if src.bin.GetState() == gst.StatePlaying { - srcPad.AddProbe(gst.PadProbeTypeBlockDownstream, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn { - if err = sink.bin.SetState(gst.StatePlaying); err != nil { - src.OnError(errors.ErrGstPipelineError(err)) - return gst.PadProbeUnhandled - } + srcState := src.bin.GetState() + sinkState := sink.bin.GetState() - return gst.PadProbeRemove - }) + if srcState != sinkState { + if srcState == gst.StateNull { + srcPad.AddProbe(gst.PadProbeTypeBlockDownstream, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn { + if padReturn := srcPad.Link(sinkPad.Pad); padReturn != gst.PadLinkOK { + logger.Errorw("failed to link", errors.ErrPadLinkFailed(src.bin.GetName(), sink.bin.GetName(), padReturn.String())) + } + return gst.PadProbeRemove + }) + return src.SetState(gst.StatePlaying) + } + + if sinkState == gst.StateNull { + srcPad.AddProbe(gst.PadProbeTypeBlockDownstream, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn { + if err = sink.SetState(gst.StatePlaying); err != nil { + src.OnError(errors.ErrGstPipelineError(err)) + return gst.PadProbeUnhandled + } + + return gst.PadProbeRemove + }) + } } if padReturn := srcPad.Link(sinkPad.Pad); padReturn != gst.PadLinkOK { @@ -401,7 +458,7 @@ func (b *Bin) linkPeersWithQueueLocked(src, sink *Bin) error { return err } - srcPad, sinkPad, err := createGhostPadsWithQueue(src, sink, queue) + srcPad, sinkPad, err := createGhostPads(src, sink, queue) if err != nil { return err } @@ -412,21 +469,6 @@ func (b *Bin) linkPeersWithQueueLocked(src, sink *Bin) error { return nil } -func (b *Bin) sendEOS() { - b.mu.Lock() - defer b.mu.Unlock() - - if b.eosFunc != nil { - b.eosFunc() - } else if len(b.srcs) > 0 { - for _, src := range b.srcs { - src.sendEOS() - } - } else if len(b.elements) > 0 { - b.bin.SendEvent(gst.NewEOSEvent()) - } -} - func getPeerSrcs(srcs []*Bin) []*Bin { flattened := make([]*Bin, 0, len(srcs)) for _, src := range srcs { @@ -450,3 +492,13 @@ func getPeerSinks(sinks []*Bin) []*Bin { } return flattened } + +func getGhostPads(src, sink *Bin) (*gst.GhostPad, *gst.GhostPad) { + srcPad := src.pads[sink.bin.GetName()] + sinkPad := sink.pads[src.bin.GetName()] + + delete(src.pads, sink.bin.GetName()) + delete(sink.pads, src.bin.GetName()) + + return srcPad, sinkPad +} diff --git a/pkg/gstreamer/pads.go b/pkg/gstreamer/pads.go index 03fb8d40..09e36482 100644 --- a/pkg/gstreamer/pads.go +++ b/pkg/gstreamer/pads.go @@ -21,9 +21,10 @@ import ( "github.com/tinyzimmer/go-gst/gst" "github.com/livekit/egress/pkg/errors" + "github.com/livekit/protocol/logger" ) -func createGhostPads(src, sink *Bin) (*gst.GhostPad, *gst.GhostPad, error) { +func createGhostPads(src, sink *Bin, queue *gst.Element) (*gst.GhostPad, *gst.GhostPad, error) { srcName := src.bin.GetName() sinkName := sink.bin.GetName() @@ -36,165 +37,183 @@ func createGhostPads(src, sink *Bin) (*gst.GhostPad, *gst.GhostPad, error) { src.pads[sinkName] = srcGhostPad src.bin.AddPad(srcGhostPad.Pad) - sinkGhostPad := gst.NewGhostPad(fmt.Sprintf("%s_%s_src", srcName, sinkName), sinkPad) - sink.pads[srcName] = sinkGhostPad - sink.bin.AddPad(sinkGhostPad.Pad) - - return srcGhostPad, sinkGhostPad, nil -} - -func createGhostPadsWithQueue(src, sink *Bin, queue *gst.Element) (*gst.GhostPad, *gst.GhostPad, error) { - srcName := src.bin.GetName() - sinkName := sink.bin.GetName() + if queue != nil { + if padReturn := queue.GetStaticPad("src").Link(sinkPad); padReturn != gst.PadLinkOK { + return nil, nil, errors.ErrPadLinkFailed(queue.GetName(), sinkName, padReturn.String()) + } - srcPad, sinkPad, err := getPads(src, sink) - if err != nil { - return nil, nil, err + sinkGhostPad := gst.NewGhostPad(fmt.Sprintf("%s_%s_src", srcName, sinkName), queue.GetStaticPad("sink")) + sink.pads[srcName] = sinkGhostPad + sink.bin.AddPad(sinkGhostPad.Pad) + return srcGhostPad, sinkGhostPad, nil } - if padReturn := queue.GetStaticPad("src").Link(sinkPad); padReturn != gst.PadLinkOK { - return nil, nil, errors.ErrPadLinkFailed(queue.GetName(), sinkName, padReturn.String()) - } - - srcGhostPad := gst.NewGhostPad(fmt.Sprintf("%s_%s_sink", srcName, sinkName), srcPad) - src.pads[sinkName] = srcGhostPad - src.bin.AddPad(srcGhostPad.Pad) - sinkGhostPad := gst.NewGhostPad(fmt.Sprintf("%s_%s_src", srcName, sinkName), queue.GetStaticPad("sink")) + sinkGhostPad := gst.NewGhostPad(fmt.Sprintf("%s_%s_src", srcName, sinkName), sinkPad) sink.pads[srcName] = sinkGhostPad sink.bin.AddPad(sinkGhostPad.Pad) - return srcGhostPad, sinkGhostPad, nil } -func getGhostPads(src, sink *Bin) (*gst.GhostPad, *gst.GhostPad) { - srcPad := src.pads[sink.bin.GetName()] - sinkPad := sink.pads[src.bin.GetName()] - - delete(src.pads, sink.bin.GetName()) - delete(sink.pads, src.bin.GetName()) - - return srcPad, sinkPad -} - func getPads(src, sink *Bin) (*gst.Pad, *gst.Pad, error) { var srcPad, sinkPad *gst.Pad - srcElement := src.elements[len(src.elements)-1] - sinkElement := sink.elements[0] - + var srcTemplates, sinkTemplates []*padTemplate if src.getSinkPad != nil { srcPad = src.getSinkPad(sink.bin.GetName()) - for _, padTemplate := range sinkElement.GetPadTemplates() { - if padTemplate.Direction() == gst.PadDirectionSink { - return srcPad, getPad(sinkElement, padTemplate), nil - } - } + } else { + srcTemplates = src.getPadTemplates(gst.PadDirectionSource) } if sink.getSrcPad != nil { sinkPad = sink.getSrcPad(src.bin.GetName()) - for _, padTemplate := range srcElement.GetPadTemplates() { - if padTemplate.Direction() == gst.PadDirectionSource { - return getPad(srcElement, padTemplate), sinkPad, nil - } - } + } else { + sinkTemplates = sink.getPadTemplates(gst.PadDirectionSink) } - srcPrimary, srcSecondary, srcAny := getPadTemplates(src.elements, gst.PadDirectionSource) - sinkPrimary, sinkSecondary, sinkAny := getPadTemplates(sink.elements, gst.PadDirectionSink) - - for srcCaps, srcTemplate := range srcPrimary { - if sinkTemplate, ok := sinkPrimary[srcCaps]; ok { - return getPad(srcElement, srcTemplate), getPad(sinkElement, sinkTemplate), nil + switch { + case srcPad != nil && sinkPad != nil: + return srcPad, sinkPad, nil + case srcPad != nil && len(sinkTemplates) == 1: + return srcPad, sinkTemplates[0].toPad(), nil + case sinkPad != nil && len(srcTemplates) == 1: + return srcTemplates[0].toPad(), sinkPad, nil + case len(srcTemplates) >= 1 && len(srcTemplates) >= 1: + for _, srcTemplate := range srcTemplates { + if sinkTemplate := srcTemplate.findDirectMatch(sinkTemplates); sinkTemplate != nil { + return srcTemplate.toPad(), sinkTemplate.toPad(), nil + } } - } - for dataType, srcTemplate := range srcSecondary { - if sinkTemplate, ok := sinkSecondary[dataType]; ok { - return getPad(srcElement, srcTemplate), getPad(sinkElement, sinkTemplate), nil + for _, srcTemplate := range srcTemplates { + if sinkTemplate := srcTemplate.findAnyMatch(sinkTemplates); sinkTemplate != nil { + return srcTemplate.toPad(), sinkTemplate.toPad(), nil + } } } - if srcAny != nil && sinkAny != nil { - return getPad(srcElement, srcAny), getPad(sinkElement, sinkAny), nil - } + logger.Warnw("could not match pads", nil, "srcTemplates", srcTemplates, "sinkTemplates", sinkTemplates) return nil, nil, errors.ErrGhostPadFailed } -func getPad(e *gst.Element, template *gst.PadTemplate) *gst.Pad { - if template.Presence() == gst.PadPresenceAlways { - return e.GetStaticPad(template.Name()) +type padTemplate struct { + element *gst.Element + template *gst.PadTemplate + capsNames map[string]struct{} + dataTypes map[string]struct{} +} + +func (p *padTemplate) toPad() *gst.Pad { + if p.template.Presence() == gst.PadPresenceAlways { + return p.element.GetStaticPad(p.template.Name()) } else { - return e.GetRequestPad(template.Name()) + return p.element.GetRequestPad(p.template.Name()) + } +} + +func (p *padTemplate) findDirectMatch(others []*padTemplate) *padTemplate { + for _, other := range others { + for capsName := range p.capsNames { + if _, ok := other.capsNames[capsName]; ok { + return other + } + } + for dataType := range p.dataTypes { + if _, ok := other.dataTypes[dataType]; ok { + return other + } + } } + return nil } -func getPadTemplates(elements []*gst.Element, direction gst.PadDirection) ( - map[string]*gst.PadTemplate, - map[string]*gst.PadTemplate, - *gst.PadTemplate, -) { - primary := make(map[string]*gst.PadTemplate) - secondary := make(map[string]*gst.PadTemplate) - var anyTemplate *gst.PadTemplate +func (p *padTemplate) findAnyMatch(others []*padTemplate) *padTemplate { + for _, other := range others { + if _, ok := p.dataTypes["ANY"]; ok { + return other + } + if _, ok := other.dataTypes["ANY"]; ok { + return other + } + } + return nil +} - var i int +func (b *Bin) getPadTemplates(direction gst.PadDirection) []*padTemplate { + var element *gst.Element if direction == gst.PadDirectionSource { - i = len(elements) - 1 + element = b.elements[len(b.elements)-1] + } else { + element = b.elements[0] } - for i >= 0 && i < len(elements) { - padTemplates := elements[i].GetPadTemplates() - for _, padTemplate := range padTemplates { - if padTemplate.Direction() == direction { - caps := padTemplate.Caps() - - if caps.IsAny() { - if strings.HasPrefix(padTemplate.Name(), direction.String()) { - // most generic pad - if anyTemplate == nil { - anyTemplate = padTemplate - } else { - continue - } + allTemplates := element.GetPadTemplates() + templates := make([]*padTemplate, 0) + + for _, template := range allTemplates { + if template.Direction() == direction { + t := &padTemplate{ + element: element, + template: template, + capsNames: make(map[string]struct{}), + dataTypes: make(map[string]struct{}), + } + + caps := template.Caps() + if caps.IsAny() { + if strings.HasPrefix(template.Name(), direction.String()) { + // src/src_%u/sink/sink_%u pad + capsNames, dataTypes, ok := b.getTypes(direction) + if ok { + t.capsNames = capsNames + t.dataTypes = dataTypes } else { - // any caps but associated name - dataType := padTemplate.Name() - if strings.HasSuffix(dataType, "_%u") { - dataType = dataType[:len(dataType)-3] - } - if anyTemplate != nil { - secondary[dataType] = anyTemplate - return primary, secondary, nil - } - secondary[dataType] = padTemplate + t.dataTypes["ANY"] = struct{}{} } } else { - // specified caps + // audio/audio_%u/video/video_%u pad + dataType := template.Name() + if strings.HasSuffix(dataType, "_%u") { + dataType = dataType[:len(dataType)-3] + } + t.dataTypes[dataType] = struct{}{} + } + } else { + // pad has caps + splitCaps := strings.Split(caps.String(), "; ") + for _, c := range splitCaps { + capsName := strings.SplitN(c, ",", 2)[0] + t.capsNames[capsName] = struct{}{} + t.dataTypes[strings.Split(capsName, "/")[0]] = struct{}{} + } + } + + templates = append(templates, t) + } + } + + return templates +} + +func (b *Bin) getTypes(direction gst.PadDirection) (map[string]struct{}, map[string]struct{}, bool) { + var i int + if direction == gst.PadDirectionSource { + i = len(b.elements) - 1 + } + + for i >= 0 && i < len(b.elements) { + allTemplates := b.elements[i].GetPadTemplates() + for _, template := range allTemplates { + if template.Direction() == gst.PadDirectionSource { + if caps := template.Caps(); !caps.IsAny() { + capsNames := make(map[string]struct{}) + dataTypes := make(map[string]struct{}) splitCaps := strings.Split(caps.String(), ";") for _, c := range splitCaps { capsName := strings.SplitN(c, ",", 2)[0] - dataType := strings.Split(capsName, "/")[0] - if anyTemplate != nil { - primary[capsName] = anyTemplate - secondary[dataType] = anyTemplate - } else { - primary[capsName] = padTemplate - secondary[dataType] = padTemplate - } - } - if anyTemplate != nil { - return primary, secondary, anyTemplate + capsNames[capsName] = struct{}{} + dataTypes[strings.Split(capsName, "/")[0]] = struct{}{} } + return capsNames, dataTypes, true } } } - if anyTemplate == nil { - for _, template := range primary { - return primary, secondary, template - } - for _, template := range secondary { - return primary, secondary, template - } - return primary, secondary, nil - } if direction == gst.PadDirectionSource { i-- @@ -203,5 +222,21 @@ func getPadTemplates(elements []*gst.Element, direction gst.PadDirection) ( } } - return primary, secondary, anyTemplate + if direction == gst.PadDirectionSource { + for _, src := range b.srcs { + capsNames, dataTypes, ok := src.getTypes(direction) + if ok { + return capsNames, dataTypes, true + } + } + } else { + for _, sink := range b.sinks { + capsNames, dataTypes, ok := sink.getTypes(direction) + if ok { + return capsNames, dataTypes, true + } + } + } + + return nil, nil, false } diff --git a/pkg/gstreamer/pipeline.go b/pkg/gstreamer/pipeline.go index 61e23795..d6d44b30 100644 --- a/pkg/gstreamer/pipeline.go +++ b/pkg/gstreamer/pipeline.go @@ -15,12 +15,18 @@ package gstreamer import ( + "time" + "github.com/frostbyte73/core" "github.com/tinyzimmer/go-glib/glib" "github.com/tinyzimmer/go-gst/gst" "github.com/livekit/egress/pkg/errors" - "github.com/livekit/protocol/logger" +) + +const ( + stateChangeTimeout = time.Second * 15 + stopTimeout = time.Second * 30 ) type Pipeline struct { @@ -30,8 +36,8 @@ type Pipeline struct { binsAdded bool elementsAdded bool - started core.Fuse running chan struct{} + stopped core.Fuse } // A pipeline can have either elements or src and sink bins. If you add both you will get a wrong hierarchy error @@ -44,15 +50,16 @@ func NewPipeline(name string, latency uint64, callbacks *Callbacks) (*Pipeline, return &Pipeline{ Bin: &Bin{ - Callbacks: callbacks, - pipeline: pipeline, - bin: pipeline.Bin, - latency: latency, - queues: make(map[string]*gst.Element), + Callbacks: callbacks, + StateManager: &StateManager{}, + pipeline: pipeline, + bin: pipeline.Bin, + latency: latency, + queues: make(map[string]*gst.Element), }, loop: glib.NewMainLoop(glib.MainContextDefault(), false), - started: core.NewFuse(), running: make(chan struct{}), + stopped: core.NewFuse(), }, nil } @@ -97,22 +104,35 @@ func (p *Pipeline) SetWatch(watch func(msg *gst.Message) bool) { } func (p *Pipeline) SetState(state gst.State) error { - if err := p.pipeline.SetState(state); err != nil { - return errors.ErrGstPipelineError(err) + p.mu.Lock() + defer p.mu.Unlock() + + stateErr := make(chan error, 1) + go func() { + stateErr <- p.pipeline.SetState(state) + }() + select { + case <-time.After(stateChangeTimeout): + return errors.ErrPipelineFrozen + case err := <-stateErr: + if err != nil { + return errors.ErrGstPipelineError(err) + } } + return nil } func (p *Pipeline) Run() error { - p.started.Once(func() { + if _, ok := p.UpgradeState(StateStarted); ok { if err := p.SetState(gst.StatePlaying); err != nil { - p.OnError(err) - return + return err + } + if _, ok = p.UpgradeState(StateRunning); ok { + p.loop.Run() } - logger.Infow("running") - p.loop.Run() close(p.running) - }) + } // wait <-p.running @@ -120,19 +140,34 @@ func (p *Pipeline) Run() error { } func (p *Pipeline) SendEOS() { - p.sendEOS() + old, ok := p.UpgradeState(StateEOS) + if ok { + if old >= StateRunning { + p.sendEOS() + } else { + p.Stop() + } + } } func (p *Pipeline) Stop() { - defer p.loop.Quit() - if err := p.SetState(gst.StateNull); err != nil { - p.OnError(err) + old, ok := p.UpgradeState(StateStopping) + if !ok { return } + if err := p.OnStop(); err != nil { p.OnError(err) - return } + + if old >= StateRunning { + p.loop.Quit() + } + + p.UpgradeState(StateFinished) + go func() { + _ = p.pipeline.SetState(gst.StateNull) + }() } func (p *Pipeline) DebugBinToDotData(details gst.DebugGraphDetails) string { diff --git a/pkg/gstreamer/state.go b/pkg/gstreamer/state.go new file mode 100644 index 00000000..b17654a9 --- /dev/null +++ b/pkg/gstreamer/state.go @@ -0,0 +1,98 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gstreamer + +import ( + "fmt" + "sync" + + "github.com/livekit/protocol/logger" +) + +type State int + +const ( + StateBuilding State = iota + StateStarted + StateRunning + StateEOS + StateStopping + StateFinished +) + +type StateManager struct { + lock sync.RWMutex + state State +} + +func (s *StateManager) GetState() State { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.state +} + +func (s *StateManager) GetStateLocked() State { + return s.state +} + +func (s *StateManager) LockState() { + s.lock.Lock() +} + +func (s *StateManager) UnlockState() { + s.lock.Unlock() +} + +func (s *StateManager) LockStateShared() { + s.lock.RLock() +} + +func (s *StateManager) UnlockStateShared() { + s.lock.RUnlock() +} + +func (s *StateManager) UpgradeState(state State) (State, bool) { + s.lock.Lock() + defer s.lock.Unlock() + + old := s.state + if old >= state { + return old, false + } else { + logger.Debugw(fmt.Sprintf("pipeline state %v -> %v", old, state)) + s.state = state + return old, true + } +} + +func (s State) String() string { + switch s { + case StateBuilding: + return "building" + case StateStarted: + return "starting" + case StateRunning: + return "running" + case StateEOS: + return "eos" + case StateStopping: + return "stopping" + case StateFinished: + return "finished" + default: + return "unknown" + } +} diff --git a/pkg/pipeline/builder/audio.go b/pkg/pipeline/builder/audio.go index 2198cc2f..55c101d3 100644 --- a/pkg/pipeline/builder/audio.go +++ b/pkg/pipeline/builder/audio.go @@ -51,6 +51,14 @@ func BuildAudioBin(pipeline *gstreamer.Pipeline, p *config.PipelineConfig) (*gst if err = b.AddElement(tee); err != nil { return nil, err } + } else { + queue, err := gstreamer.BuildQueue("audio_queue", p.Latency, true) + if err != nil { + return nil, errors.ErrGstPipelineError(err) + } + if err = b.AddElement(queue); err != nil { + return nil, err + } } return b, nil @@ -106,7 +114,9 @@ func buildAudioAppSrcBin(audioBin *gstreamer.Bin, p *config.PipelineConfig) erro track := p.AudioTrack b := audioBin.NewBin(track.TrackID) - b.SetEOSFunc(track.EOSFunc) + b.SetEOSFunc(func() bool { + return false + }) if err := audioBin.AddSourceBin(b); err != nil { return err } diff --git a/pkg/pipeline/builder/video.go b/pkg/pipeline/builder/video.go index 0eca1e85..ffa27020 100644 --- a/pkg/pipeline/builder/video.go +++ b/pkg/pipeline/builder/video.go @@ -57,6 +57,14 @@ func BuildVideoBin(pipeline *gstreamer.Pipeline, p *config.PipelineConfig) (*gst if err = b.AddElement(tee); err != nil { return nil, err } + } else { + queue, err := gstreamer.BuildQueue("video_queue", p.Latency, true) + if err != nil { + return nil, errors.ErrGstPipelineError(err) + } + if err = b.AddElement(queue); err != nil { + return nil, err + } } return b, nil @@ -169,7 +177,9 @@ func (v *videoSDKBin) buildVideoAppSrcBin(videoBin *gstreamer.Bin, p *config.Pip track := p.VideoTrack b := videoBin.NewBin(track.TrackID) - b.SetEOSFunc(track.EOSFunc) + b.SetEOSFunc(func() bool { + return false + }) if err := videoBin.AddSourceBin(b); err != nil { return err } diff --git a/pkg/pipeline/controller.go b/pkg/pipeline/controller.go index 618c4899..2f29fd41 100644 --- a/pkg/pipeline/controller.go +++ b/pkg/pipeline/controller.go @@ -56,9 +56,8 @@ type Controller struct { gstLogger *zap.SugaredLogger limitTimer *time.Timer playing core.Fuse - eosSent core.Fuse + eos core.Fuse stopped core.Fuse - eosTimer *time.Timer } func New(ctx context.Context, conf *config.PipelineConfig) (*Controller, error) { @@ -73,7 +72,7 @@ func New(ctx context.Context, conf *config.PipelineConfig) (*Controller, error) }, gstLogger: logger.GetLogger().(*logger.ZapLogger).ToZap().WithOptions(zap.WithCaller(false)), playing: core.NewFuse(), - eosSent: core.NewFuse(), + eos: core.NewFuse(), stopped: core.NewFuse(), } c.callbacks.SetOnError(c.OnError) @@ -109,15 +108,22 @@ func New(ctx context.Context, conf *config.PipelineConfig) (*Controller, error) } func (c *Controller) BuildPipeline() error { - logger.Debugw("building pipeline") - p, err := gstreamer.NewPipeline(pipelineName, c.Latency, c.callbacks) if err != nil { return errors.ErrGstPipelineError(err) } p.SetWatch(c.messageWatch) - p.AddOnStop(c.OnStop) + p.AddOnStop(func() error { + c.stopped.Break() + return nil + }) + if c.SourceType == types.SourceTypeSDK { + p.SetEOSFunc(func() bool { + c.src.(*source.SDKSource).CloseWriters() + return true + }) + } if c.AudioEnabled { audioBin, err := builder.BuildAudioBin(p, c.PipelineConfig) @@ -180,6 +186,9 @@ func (c *Controller) Run(ctx context.Context) *livekit.EgressInfo { now := time.Now().UnixNano() c.Info.UpdatedAt = now c.Info.EndedAt = now + if c.SourceType == types.SourceTypeSDK { + c.updateDuration(c.src.GetEndedAt()) + } // update status if c.Info.Error != "" { @@ -274,7 +283,7 @@ func (c *Controller) startSessionLimitTimer(ctx context.Context) { if c.playing.IsBroken() { c.SendEOS(ctx) } else { - c.Stop() + c.p.Stop() } }) } @@ -440,11 +449,12 @@ func (c *Controller) SendEOS(ctx context.Context) { ctx, span := tracer.Start(ctx, "Pipeline.SendEOS") defer span.End() - c.eosSent.Once(func() { + c.eos.Once(func() { + logger.Debugw("Sending EOS") + if c.limitTimer != nil { c.limitTimer.Stop() } - switch c.Info.Status { case livekit.EgressStatus_EGRESS_STARTING: c.Info.Status = livekit.EgressStatus_EGRESS_ABORTED @@ -452,13 +462,13 @@ func (c *Controller) SendEOS(ctx context.Context) { case livekit.EgressStatus_EGRESS_ABORTED, livekit.EgressStatus_EGRESS_FAILED: - c.Stop() + c.p.Stop() case livekit.EgressStatus_EGRESS_ACTIVE: c.Info.UpdatedAt = time.Now().UnixNano() if c.Info.Error != "" { c.Info.Status = livekit.EgressStatus_EGRESS_FAILED - c.Stop() + c.p.Stop() } else { c.Info.Status = livekit.EgressStatus_EGRESS_ENDING c.OnUpdate(ctx, c.Info) @@ -467,24 +477,7 @@ func (c *Controller) SendEOS(ctx context.Context) { case livekit.EgressStatus_EGRESS_ENDING, livekit.EgressStatus_EGRESS_LIMIT_REACHED: - go func() { - logger.Infow("sending EOS to pipeline") - - c.eosTimer = time.AfterFunc(eosTimeout, func() { - logger.Errorw("pipeline frozen", nil, "stream", !c.FinalizationRequired) - if c.Debug.EnableProfiling { - c.uploadDebugFiles() - } - - if c.FinalizationRequired { - c.OnError(errors.ErrPipelineFrozen) - } else { - c.Stop() - } - }) - - c.p.SendEOS() - }() + go c.p.SendEOS() } switch c.src.(type) { @@ -495,30 +488,17 @@ func (c *Controller) SendEOS(ctx context.Context) { } func (c *Controller) OnError(err error) { - if c.Info.Error == "" && (!c.eosSent.IsBroken() || c.FinalizationRequired) { - c.Info.Error = err.Error() - } - go c.Stop() -} - -func (c *Controller) Stop() { - c.stopped.Once(c.p.Stop) -} - -func (c *Controller) OnStop() error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.eosTimer != nil { - c.eosTimer.Stop() + if errors.Is(err, errors.ErrPipelineFrozen) { + if c.Debug.EnableProfiling { + c.uploadDebugFiles() + } } - switch c.src.(type) { - case *source.SDKSource: - c.updateDuration(c.src.GetEndedAt()) + if c.Info.Error == "" && (!c.eos.IsBroken() || c.FinalizationRequired) { + c.Info.Error = err.Error() } - return nil + go c.p.Stop() } func (c *Controller) updateDuration(endedAt int64) { diff --git a/pkg/pipeline/source/sdk.go b/pkg/pipeline/source/sdk.go index 9eb18bf0..983a6c78 100644 --- a/pkg/pipeline/source/sdk.go +++ b/pkg/pipeline/source/sdk.go @@ -93,10 +93,6 @@ func (s *SDKSource) StartRecording() chan struct{} { return s.startRecording } -func (s *SDKSource) GetStartTime() int64 { - return s.sync.GetStartedAt() -} - func (s *SDKSource) Playing(trackID string) { if w := s.getWriterForTrack(trackID); w != nil { w.Play() @@ -107,6 +103,10 @@ func (s *SDKSource) EndRecording() chan struct{} { return s.endRecording } +func (s *SDKSource) GetStartedAt() int64 { + return s.sync.GetStartedAt() +} + func (s *SDKSource) GetEndedAt() int64 { return s.sync.GetEndedAt() } @@ -294,7 +294,6 @@ func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.Remo return } - ts.EOSFunc = s.CloseWriters s.audioWriter = writer s.AudioTrack = ts @@ -314,7 +313,6 @@ func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.Remo return } - ts.EOSFunc = s.CloseWriters s.videoWriter = writer s.VideoTrack = ts diff --git a/pkg/pipeline/source/source.go b/pkg/pipeline/source/source.go index a0b7aed9..49ad6727 100644 --- a/pkg/pipeline/source/source.go +++ b/pkg/pipeline/source/source.go @@ -26,6 +26,7 @@ import ( type Source interface { StartRecording() chan struct{} EndRecording() chan struct{} + GetStartedAt() int64 GetEndedAt() int64 Close() } diff --git a/pkg/pipeline/source/web.go b/pkg/pipeline/source/web.go index 4f1032f8..5e868379 100644 --- a/pkg/pipeline/source/web.go +++ b/pkg/pipeline/source/web.go @@ -95,6 +95,10 @@ func (s *WebSource) EndRecording() chan struct{} { return s.endRecording } +func (s *WebSource) GetStartedAt() int64 { + return time.Now().UnixNano() +} + func (s *WebSource) GetEndedAt() int64 { return time.Now().UnixNano() } diff --git a/pkg/pipeline/watch.go b/pkg/pipeline/watch.go index ea35be63..62d7e9b3 100644 --- a/pkg/pipeline/watch.go +++ b/pkg/pipeline/watch.go @@ -122,8 +122,8 @@ func (c *Controller) messageWatch(msg *gst.Message) bool { var err error switch msg.Type() { case gst.MessageEOS: - logger.Infow("EOS received, stopping pipeline") - c.Stop() + logger.Infow("EOS received") + c.p.Stop() return false case gst.MessageWarning: err = c.handleMessageWarning(msg.ParseWarning()) @@ -164,7 +164,7 @@ func (c *Controller) handleMessageError(gErr *gst.GError) error { switch { case element == elementGstRtmp2Sink: - if strings.HasPrefix(gErr.Error(), "Connection error") && !c.eosSent.IsBroken() { + if strings.HasPrefix(gErr.Error(), "Connection error") && !c.eos.IsBroken() { // try reconnecting ok, err := c.streamBin.ResetStream(name, gErr) if err != nil { @@ -194,7 +194,7 @@ func (c *Controller) handleMessageError(gErr *gst.GError) error { case element == elementSplitMuxSink: // We sometimes get GstSplitMuxSink errors if send EOS before the first media was sent to the mux if message == msgMuxer { - if c.eosSent.IsBroken() { + if c.eos.IsBroken() { logger.Debugw("GstSplitMuxSink failure after sending EOS") return nil } @@ -232,15 +232,10 @@ func (c *Controller) handleMessageStateChanged(msg *gst.Message) { s := msg.Source() if s == pipelineName { - logger.Infow("pipeline playing") - - c.playing.Break() - switch c.SourceType { - case types.SourceTypeSDK: - c.updateStartTime(c.src.(*source.SDKSource).GetStartTime()) - case types.SourceTypeWeb: - c.updateStartTime(time.Now().UnixNano()) - } + c.playing.Once(func() { + logger.Infow("pipeline playing") + c.updateStartTime(c.src.GetStartedAt()) + }) } else if strings.HasPrefix(s, "app_") { s = s[4:] logger.Infow(fmt.Sprintf("%s playing", s))