Skip to content

Commit

Permalink
Drop upstream events when removing src bin (#635)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostbyte73 authored Mar 12, 2024
1 parent 9e29802 commit a9aaeeb
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 70 deletions.
19 changes: 13 additions & 6 deletions pkg/gstreamer/bin.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,24 @@ func (b *Bin) AddElements(elements ...*gst.Element) error {
}

func (b *Bin) RemoveSourceBin(name string) (bool, error) {
logger.Debugw(fmt.Sprintf("removing src %s from %s", name, b.bin.GetName()))
return b.removeBin(name, gst.PadDirectionSource)
}

func (b *Bin) RemoveSinkBin(name string) (bool, error) {
logger.Debugw(fmt.Sprintf("removing sink %s from %s", name, b.bin.GetName()))
return b.removeBin(name, gst.PadDirectionSink)
}

func (b *Bin) removeBin(name string, direction gst.PadDirection) (bool, error) {
b.LockStateShared()
defer b.UnlockStateShared()

state := b.GetStateLocked()
if state > StateRunning {
return true, nil
}

b.mu.Lock()
defer b.mu.Unlock()

Expand All @@ -186,11 +193,6 @@ func (b *Bin) removeBin(name string, direction gst.PadDirection) (bool, error) {
return false, nil
}

state := b.GetStateLocked()
if state > StateRunning {
return true, nil
}

if state == StateBuilding {
if err := b.pipeline.Remove(bin.bin.Element); err != nil {
return false, errors.ErrGstPipelineError(err)
Expand All @@ -216,7 +218,12 @@ func (b *Bin) probeRemoveSource(src *Bin) {
}

sinkPad := sinkGhostPad.GetTarget()
srcGhostPad.AddProbe(gst.PadProbeTypeBlocking, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn {
sinkPad.AddProbe(gst.PadProbeTypeBlockUpstream, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn {
// drop all upstream events
return gst.PadProbeDrop
})

srcGhostPad.AddProbe(gst.PadProbeTypeIdle, func(_ *gst.Pad, _ *gst.PadProbeInfo) gst.PadProbeReturn {
b.elements[0].ReleaseRequestPad(sinkPad)

srcGhostPad.Unlink(sinkGhostPad.Pad)
Expand Down
5 changes: 3 additions & 2 deletions pkg/gstreamer/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import (
)

type Callbacks struct {
mu sync.RWMutex
GstReady chan struct{}
mu sync.RWMutex
GstReady chan struct{}
BuildReady chan struct{}

// upstream callbacks
onError func(error)
Expand Down
1 change: 1 addition & 0 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func (h *Handler) Run() error {
select {
case <-kill:
// kill signal received
h.conf.Info.Details = "service terminated by deployment"
h.controller.SendEOS(ctx)

case res := <-result:
Expand Down
9 changes: 5 additions & 4 deletions pkg/pipeline/builder/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package builder

import (
"fmt"
"math/rand"
"sync"

"github.com/go-gst/go-gst/gst"
Expand All @@ -34,8 +33,9 @@ type AudioBin struct {
bin *gstreamer.Bin
conf *config.PipelineConfig

mu sync.Mutex
names map[string]string
mu sync.Mutex
nextID int
names map[string]string
}

func BuildAudioBin(pipeline *gstreamer.Pipeline, p *config.PipelineConfig) error {
Expand Down Expand Up @@ -159,7 +159,8 @@ func (b *AudioBin) addAudioAppSrcBin(ts *config.TrackSource) error {
b.mu.Lock()
defer b.mu.Unlock()

name := fmt.Sprintf("%s_%d", ts.TrackID, rand.Int()%1000)
name := fmt.Sprintf("%s_%d", ts.TrackID, b.nextID)
b.nextID++
b.names[ts.TrackID] = name

appSrcBin := b.bin.NewBin(name)
Expand Down
5 changes: 3 additions & 2 deletions pkg/pipeline/builder/video.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package builder

import (
"fmt"
"math/rand"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -46,6 +45,7 @@ type VideoBin struct {
nextPad string

mu sync.Mutex
nextID int
pads map[string]*gst.Pad
names map[string]string
selector *gst.Element
Expand Down Expand Up @@ -282,7 +282,8 @@ func (b *VideoBin) buildSDKInput() error {
}

func (b *VideoBin) addAppSrcBin(ts *config.TrackSource) error {
name := fmt.Sprintf("%s_%d", ts.TrackID, rand.Int()%1000)
name := fmt.Sprintf("%s_%d", ts.TrackID, b.nextID)
b.nextID++

appSrcBin, err := b.buildAppSrcBin(ts, name)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion pkg/pipeline/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ func New(ctx context.Context, conf *config.PipelineConfig, ioClient rpc.IOInfoCl
c := &Controller{
PipelineConfig: conf,
callbacks: &gstreamer.Callbacks{
GstReady: make(chan struct{}),
GstReady: make(chan struct{}),
BuildReady: make(chan struct{}),
},
ioClient: ioClient,
gstLogger: logger.GetLogger().(logger.ZapLogger).ToZap().WithOptions(zap.WithCaller(false)),
Expand Down Expand Up @@ -186,6 +187,7 @@ func (c *Controller) BuildPipeline() error {
}

c.p = p
close(c.callbacks.BuildReady)
return nil
}

Expand Down
94 changes: 58 additions & 36 deletions pkg/pipeline/source/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type SDKSource struct {
errors chan error

writers map[string]*sdk.AppWriter
subLock sync.RWMutex
active atomic.Int32
closed core.Fuse

Expand All @@ -76,6 +77,7 @@ func NewSDKSource(ctx context.Context, p *config.PipelineConfig, callbacks *gstr
close(startRecording)
}),
filenameReplacements: make(map[string]string),
errors: make(chan error, 2),
writers: make(map[string]*sdk.AppWriter),
startRecording: startRecording,
endRecording: make(chan struct{}),
Expand Down Expand Up @@ -171,7 +173,8 @@ func (s *SDKSource) joinRoom() error {
switch s.RequestType {
case types.RequestTypeParticipant:
fileIdentifier = s.Identity
w, h, err = s.awaitParticipant(s.Identity)
s.filenameReplacements["{publisher_identity}"] = s.Identity
w, h, err = s.awaitParticipantTracks(s.Identity)

case types.RequestTypeTrackComposite:
fileIdentifier = s.Info.RoomName
Expand Down Expand Up @@ -200,42 +203,56 @@ func (s *SDKSource) joinRoom() error {
return nil
}

func (s *SDKSource) awaitParticipant(identity string) (uint32, uint32, error) {
s.errors = make(chan error, 2)

func (s *SDKSource) awaitParticipantTracks(identity string) (uint32, uint32, error) {
rp, err := s.getParticipant(identity)
if err != nil {
return 0, 0, err
}

pubs := rp.TrackPublications()
for _, t := range pubs {
if err = s.subscribe(t); err != nil {
return 0, 0, err
expected := 0
for _, pub := range pubs {
if shouldSubscribe(pub) {
expected++
}
}

for trackCount := 0; trackCount == 0 || trackCount < len(pubs); trackCount++ {
// await all expected subscriptions
for trackCount := 0; trackCount < expected; trackCount++ {
select {
case err = <-s.errors:
if err != nil {
return 0, 0, err
}
case <-s.endRecording:
return 0, 0, nil
}
}

var w, h uint32
for _, t := range pubs {
if t.TrackInfo().Type == livekit.TrackType_VIDEO {
w = t.TrackInfo().Width
h = t.TrackInfo().Height
// lock any incoming subscriptions
s.subLock.Lock()
defer s.subLock.Unlock()

for {
select {
// check errors from any tracks published in the meantime
case err = <-s.errors:
if err != nil {
return 0, 0, err
}
default:
// get dimensions after subscribing so that track info exists
var w, h uint32
for _, t := range pubs {
if t.TrackInfo().Type == livekit.TrackType_VIDEO && t.IsSubscribed() {
w = t.TrackInfo().Width
h = t.TrackInfo().Height
}
}

// ready
s.initialized.Break()
return w, h, nil
}
}

s.initialized.Break()
return w, h, nil
}

func (s *SDKSource) getParticipant(identity string) (*lksdk.RemoteParticipant, error) {
Expand All @@ -253,7 +270,6 @@ func (s *SDKSource) getParticipant(identity string) (*lksdk.RemoteParticipant, e

func (s *SDKSource) awaitTracks(expecting map[string]struct{}) (uint32, uint32, error) {
trackCount := len(expecting)
s.errors = make(chan error, trackCount)

deadline := time.After(subscriptionTimeout)
tracks, err := s.subscribeToTracks(expecting, deadline)
Expand All @@ -263,7 +279,7 @@ func (s *SDKSource) awaitTracks(expecting map[string]struct{}) (uint32, uint32,

for i := 0; i < trackCount; i++ {
select {
case err := <-s.errors:
case err = <-s.errors:
if err != nil {
return 0, 0, err
}
Expand Down Expand Up @@ -334,7 +350,10 @@ func (s *SDKSource) subscribe(track lksdk.TrackPublication) error {
// ----- Callbacks -----

func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.RemoteTrackPublication, rp *lksdk.RemoteParticipant) {
s.subLock.RLock()

if s.initialized.IsBroken() && s.RequestType != types.RequestTypeParticipant {
s.subLock.RUnlock()
return
}

Expand All @@ -347,6 +366,7 @@ func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.Remo
} else {
s.errors <- onSubscribeErr
}
s.subLock.RUnlock()
}()

s.active.Inc()
Expand Down Expand Up @@ -378,9 +398,7 @@ func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.Remo
s.writers[ts.TrackID] = writer
s.mu.Unlock()

if s.initialized.IsBroken() {
s.callbacks.OnTrackAdded(ts)
} else {
if !s.initialized.IsBroken() {
s.AudioTrack = ts
}

Expand All @@ -407,9 +425,7 @@ func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.Remo
s.writers[ts.TrackID] = writer
s.mu.Unlock()

if s.initialized.IsBroken() {
s.callbacks.OnTrackAdded(ts)
} else {
if !s.initialized.IsBroken() {
s.VideoTrack = ts
}

Expand All @@ -418,12 +434,12 @@ func (s *SDKSource) onTrackSubscribed(track *webrtc.TrackRemote, pub *lksdk.Remo
return
}

if !s.initialized.IsBroken() {
if s.initialized.IsBroken() {
<-s.callbacks.BuildReady
s.callbacks.OnTrackAdded(ts)
} else {
s.mu.Lock()
switch s.RequestType {
case types.RequestTypeParticipant:
s.filenameReplacements["{publisher_identity}"] = s.Identity

case types.RequestTypeTrackComposite:
if s.Identity == "" || track.Kind() == webrtc.RTPCodecTypeVideo {
s.Identity = rp.Identity()
Expand Down Expand Up @@ -482,19 +498,25 @@ func (s *SDKSource) createWriter(
}

func (s *SDKSource) onTrackPublished(pub *lksdk.RemoteTrackPublication, rp *lksdk.RemoteParticipant) {
if rp.Identity() != s.Identity {
if rp.Identity() != s.Identity || s.RequestType != types.RequestTypeParticipant {
return
}

switch pub.Source() {
case livekit.TrackSource_CAMERA, livekit.TrackSource_MICROPHONE:
if shouldSubscribe(pub) {
if err := s.subscribe(pub); err != nil {
logger.Errorw("failed to subscribe to track", err, "trackID", pub.SID())
}
} else {
logger.Infow("ignoring participant track", "reason", fmt.Sprintf("source %s", pub.Source()))
}
}

func shouldSubscribe(pub lksdk.TrackPublication) bool {
switch pub.Source() {
case livekit.TrackSource_CAMERA, livekit.TrackSource_MICROPHONE:
return true
default:
logger.Infow("ignoring participant track",
"reason", fmt.Sprintf("source %s", pub.Source()))
return
return false
}
}

Expand Down
33 changes: 19 additions & 14 deletions test/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,24 @@ func (r *Runner) publish(t *testing.T, codec types.MimeType, done chan struct{})
}

func (r *Runner) startEgress(t *testing.T, req *rpc.StartEgressRequest) string {
info := r.sendRequest(t, req)

// check status
if r.HealthPort != 0 {
status := r.getStatus(t)
require.Contains(t, status, info.EgressId)
}

// wait
time.Sleep(time.Second * 5)

// check active update
r.checkUpdate(t, info.EgressId, livekit.EgressStatus_EGRESS_ACTIVE)

return info.EgressId
}

func (r *Runner) sendRequest(t *testing.T, req *rpc.StartEgressRequest) *livekit.EgressInfo {
// send start request
info, err := r.client.StartEgress(context.Background(), "", req)

Expand All @@ -221,20 +239,7 @@ func (r *Runner) startEgress(t *testing.T, req *rpc.StartEgressRequest) string {
}

require.Equal(t, livekit.EgressStatus_EGRESS_STARTING.String(), info.Status.String())

// check status
if r.HealthPort != 0 {
status := r.getStatus(t)
require.Contains(t, status, info.EgressId)
}

// wait
time.Sleep(time.Second * 5)

// check active update
r.checkUpdate(t, info.EgressId, livekit.EgressStatus_EGRESS_ACTIVE)

return info.EgressId
return info
}

func (r *Runner) checkUpdate(t *testing.T, egressID string, status livekit.EgressStatus) *livekit.EgressInfo {
Expand Down
Loading

0 comments on commit a9aaeeb

Please sign in to comment.