diff --git a/callback.go b/callback.go index af1d63e..d7a2c59 100644 --- a/callback.go +++ b/callback.go @@ -26,6 +26,10 @@ import ( type ParticipantAttributesChangedFunc func(changed map[string]string, p Participant) type ParticipantCallback struct { + // for local participant + OnLocalTrackPublished func(publication *LocalTrackPublication, lp *LocalParticipant) + OnLocalTrackUnpublished func(publication *LocalTrackPublication, lp *LocalParticipant) + // for all participants OnTrackMuted func(pub TrackPublication, p Participant) OnTrackUnmuted func(pub TrackPublication, p Participant) @@ -46,6 +50,9 @@ type ParticipantCallback struct { func NewParticipantCallback() *ParticipantCallback { return &ParticipantCallback{ + OnLocalTrackPublished: func(publication *LocalTrackPublication, lp *LocalParticipant) {}, + OnLocalTrackUnpublished: func(publication *LocalTrackPublication, lp *LocalParticipant) {}, + OnTrackMuted: func(pub TrackPublication, p Participant) {}, OnTrackUnmuted: func(pub TrackPublication, p Participant) {}, OnMetadataChanged: func(oldMetadata string, p Participant) {}, @@ -63,6 +70,12 @@ func NewParticipantCallback() *ParticipantCallback { } func (cb *ParticipantCallback) Merge(other *ParticipantCallback) { + if other.OnLocalTrackPublished != nil { + cb.OnLocalTrackPublished = other.OnLocalTrackPublished + } + if other.OnLocalTrackUnpublished != nil { + cb.OnLocalTrackUnpublished = other.OnLocalTrackUnpublished + } if other.OnTrackMuted != nil { cb.OnTrackMuted = other.OnTrackMuted } diff --git a/engine.go b/engine.go index 7ff0176..3cbdce0 100644 --- a/engine.go +++ b/engine.go @@ -62,18 +62,19 @@ type RTCEngine struct { JoinTimeout time.Duration // callbacks - OnDisconnected func(reason DisconnectionReason) - OnMediaTrack func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) - OnParticipantUpdate func([]*livekit.ParticipantInfo) - OnSpeakersChanged func([]*livekit.SpeakerInfo) - OnDataReceived func(userPacket *livekit.UserPacket) // Deprecated: Use OnDataPacket instead - OnDataPacket func(identity string, dataPacket DataPacket) - OnConnectionQuality func([]*livekit.ConnectionQualityInfo) - OnRoomUpdate func(room *livekit.Room) - OnRestarting func() - OnRestarted func(*livekit.JoinResponse) - OnResuming func() - OnResumed func() + OnLocalTrackUnpublished func(response *livekit.TrackUnpublishedResponse) + OnDisconnected func(reason DisconnectionReason) + OnMediaTrack func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) + OnParticipantUpdate func([]*livekit.ParticipantInfo) + OnSpeakersChanged func([]*livekit.SpeakerInfo) + OnDataReceived func(userPacket *livekit.UserPacket) // Deprecated: Use OnDataPacket instead + OnDataPacket func(identity string, dataPacket DataPacket) + OnConnectionQuality func([]*livekit.ConnectionQualityInfo) + OnRoomUpdate func(room *livekit.Room) + OnRestarting func() + OnRestarted func(*livekit.JoinResponse) + OnResuming func() + OnResumed func() } func NewRTCEngine() *RTCEngine { @@ -95,6 +96,7 @@ func NewRTCEngine() *RTCEngine { } } e.client.OnLocalTrackPublished = e.handleLocalTrackPublished + e.client.OnLocalTrackUnpublished = e.handleLocalTrackUnpublished e.client.OnConnectionQuality = func(cqi []*livekit.ConnectionQualityInfo) { if f := e.OnConnectionQuality; f != nil { f(cqi) @@ -474,6 +476,12 @@ func (e *RTCEngine) handleLocalTrackPublished(res *livekit.TrackPublishedRespons e.trackPublishedChan <- res } +func (e *RTCEngine) handleLocalTrackUnpublished(res *livekit.TrackUnpublishedResponse) { + if e.OnLocalTrackUnpublished != nil { + e.OnLocalTrackUnpublished(res) + } +} + func (e *RTCEngine) handleDataPacket(msg webrtc.DataChannelMessage) { packet, err := e.readDataPacket(msg) if err != nil { diff --git a/localparticipant.go b/localparticipant.go index 2ab2053..c4c4721 100644 --- a/localparticipant.go +++ b/localparticipant.go @@ -120,6 +120,9 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl pub.updateInfo(pubRes.Track) p.addPublication(pub) + p.Callback.OnLocalTrackPublished(pub, p) + p.roomCallback.OnLocalTrackPublished(pub, p) + p.engine.log.Infow("published track", "name", opts.Name, "source", opts.Source.String(), "trackID", pubRes.Track.Sid) return pub, nil @@ -231,6 +234,9 @@ func (p *LocalParticipant) PublishSimulcastTrack(tracks []*LocalTrack, opts *Tra publisher.Negotiate() + p.Callback.OnLocalTrackPublished(pub, p) + p.roomCallback.OnLocalTrackPublished(pub, p) + p.engine.log.Infow("published simulcast track", "name", opts.Name, "source", opts.Source.String(), "trackID", pubRes.Track.Sid) return pub, nil @@ -245,6 +251,11 @@ func (p *LocalParticipant) republishTracks() { localPubs = append(localPubs, track) } p.tracks.Delete(key) + p.audioTracks.Delete(key) + p.videoTracks.Delete(key) + + p.Callback.OnLocalTrackUnpublished(track, p) + p.roomCallback.OnLocalTrackUnpublished(track, p) return true }) @@ -266,11 +277,14 @@ func (p *LocalParticipant) republishTracks() { func (p *LocalParticipant) closeTracks() { var localPubs []*LocalTrackPublication - p.tracks.Range(func(_, value interface{}) bool { + p.tracks.Range(func(key, value interface{}) bool { track := value.(*LocalTrackPublication) if track.Track() != nil || len(track.simulcastTracks) > 0 { localPubs = append(localPubs, track) } + p.tracks.Delete(key) + p.audioTracks.Delete(key) + p.videoTracks.Delete(key) return true }) @@ -408,6 +422,9 @@ func (p *LocalParticipant) UnpublishTrack(sid string) error { pub.CloseTrack() + p.Callback.OnLocalTrackUnpublished(pub, p) + p.roomCallback.OnLocalTrackUnpublished(pub, p) + p.engine.log.Infow("unpublished track", "name", pub.Name(), "sid", sid) return err