diff --git a/callback.go b/callback.go index ee130016..fffb2ebf 100644 --- a/callback.go +++ b/callback.go @@ -34,7 +34,8 @@ type ParticipantCallback struct { OnTrackSubscriptionFailed func(sid string, rp *RemoteParticipant) OnTrackPublished func(publication *RemoteTrackPublication, rp *RemoteParticipant) OnTrackUnpublished func(publication *RemoteTrackPublication, rp *RemoteParticipant) - OnDataReceived func(data []byte, params DataReceiveParams) + OnDataReceived func(data []byte, params DataReceiveParams) // Deprecated: Use OnDataPacket instead + OnDataPacket func(data DataPacket, params DataReceiveParams) } func NewParticipantCallback() *ParticipantCallback { @@ -50,6 +51,7 @@ func NewParticipantCallback() *ParticipantCallback { OnTrackPublished: func(publication *RemoteTrackPublication, rp *RemoteParticipant) {}, OnTrackUnpublished: func(publication *RemoteTrackPublication, rp *RemoteParticipant) {}, OnDataReceived: func(data []byte, params DataReceiveParams) {}, + OnDataPacket: func(data DataPacket, params DataReceiveParams) {}, } } @@ -87,6 +89,9 @@ func (cb *ParticipantCallback) Merge(other *ParticipantCallback) { if other.OnDataReceived != nil { cb.OnDataReceived = other.OnDataReceived } + if other.OnDataPacket != nil { + cb.OnDataPacket = other.OnDataPacket + } } type RoomCallback struct { diff --git a/data.go b/data.go index 6d1977ee..14e75797 100644 --- a/data.go +++ b/data.go @@ -1,7 +1,7 @@ package lksdk type dataPublishOptions struct { - Reliable bool + Reliable *bool DestinationIdentities []string Topic string } @@ -9,7 +9,7 @@ type dataPublishOptions struct { type DataReceiveParams struct { Sender *RemoteParticipant SenderIdentity string - Topic string + Topic string // Deprecated: Use UserDataPacket.Topic } type DataPublishOption func(*dataPublishOptions) @@ -22,7 +22,7 @@ func WithDataPublishTopic(topic string) DataPublishOption { func WithDataPublishReliable(reliable bool) DataPublishOption { return func(o *dataPublishOptions) { - o.Reliable = reliable + o.Reliable = &reliable } } diff --git a/engine.go b/engine.go index 33691eff..8d2e8a4e 100644 --- a/engine.go +++ b/engine.go @@ -64,7 +64,8 @@ type RTCEngine struct { OnMediaTrack func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) OnParticipantUpdate func([]*livekit.ParticipantInfo) OnSpeakersChanged func([]*livekit.SpeakerInfo) - OnDataReceived func(userPacket *livekit.UserPacket) + OnDataReceived func(userPacket *livekit.UserPacket) // Deprecated: Use OnDataPacket instead + OnDataPacket func(identity string, dataPacket DataPacket) OnConnectionQuality func([]*livekit.ConnectionQualityInfo) OnRoomUpdate func(room *livekit.Room) OnRestarting func() @@ -460,10 +461,34 @@ func (e *RTCEngine) handleDataPacket(msg webrtc.DataChannelMessage) { if err != nil { return } + identity := packet.ParticipantIdentity switch msg := packet.Value.(type) { case *livekit.DataPacket_User: - if e.OnDataReceived != nil { - e.OnDataReceived(msg.User) + m := msg.User + //lint:ignore SA1019 backward compatibility + if ptr := &m.ParticipantIdentity; *ptr == "" { + *ptr = identity + } + //lint:ignore SA1019 backward compatibility + if ptr := &m.DestinationIdentities; len(*ptr) == 0 { + *ptr = packet.DestinationIdentities + } + if onDataReceived := e.OnDataReceived; onDataReceived != nil { + onDataReceived(m) + } + if e.OnDataPacket != nil { + if identity == "" { + //lint:ignore SA1019 backward compatibility + identity = m.ParticipantIdentity + } + e.OnDataPacket(identity, &UserDataPacket{ + Payload: m.Payload, + Topic: m.GetTopic(), + }) + } + case *livekit.DataPacket_SipDtmf: + if e.OnDataPacket != nil { + e.OnDataPacket(identity, msg.SipDtmf) } } } diff --git a/go.mod b/go.mod index 2333b264..6908cadc 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/go-logr/stdr v1.2.2 github.com/gorilla/websocket v1.5.1 github.com/livekit/mediatransportutil v0.0.0-20240302142739-1c3dd691a1b8 - github.com/livekit/protocol v1.10.1 + github.com/livekit/protocol v1.11.0 github.com/magefile/mage v1.15.0 github.com/pion/dtls/v2 v2.2.10 github.com/pion/interceptor v0.1.25 @@ -40,7 +40,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/lithammer/shortuuid/v4 v4.0.0 // indirect github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 // indirect - github.com/livekit/psrpc v0.5.3-0.20240227154351-b7f99eaaf7b3 // indirect + github.com/livekit/psrpc v0.5.3-0.20240228172457-3724cb4adbc4 // indirect github.com/mackerelio/go-osstat v0.2.4 // indirect github.com/nats-io/nats.go v1.31.0 // indirect github.com/nats-io/nkeys v0.4.6 // indirect diff --git a/go.sum b/go.sum index 1ccd1c23..152e6a3a 100644 --- a/go.sum +++ b/go.sum @@ -75,10 +75,10 @@ github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1 h1:jm09419p0lqTkD github.com/livekit/mageutil v0.0.0-20230125210925-54e8a70427c1/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= github.com/livekit/mediatransportutil v0.0.0-20240302142739-1c3dd691a1b8 h1:xawydPEACNO5Ncs2LgioTjWghXQ0eUN1q1RnVUUyVnI= github.com/livekit/mediatransportutil v0.0.0-20240302142739-1c3dd691a1b8/go.mod h1:jwKUCmObuiEDH0iiuJHaGMXwRs3RjrB4G6qqgkr/5oE= -github.com/livekit/protocol v1.10.1 h1:upe6pKRqH8wpsMuR2OLtgizEm94iia3pDYm3O4/2PRY= -github.com/livekit/protocol v1.10.1/go.mod h1:eWPz45pnxwpCwB84qqhHxG0bCRgasa2itN6GAHCDddc= -github.com/livekit/psrpc v0.5.3-0.20240227154351-b7f99eaaf7b3 h1:bvjzDR+Rvdf3JgzQMtLiGVHBQ8KoOWM7x7sHj79jevQ= -github.com/livekit/psrpc v0.5.3-0.20240227154351-b7f99eaaf7b3/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= +github.com/livekit/protocol v1.11.0 h1:3V1j0EGfh5T8A/rb/H7kB+ak9TINA8a/2jXpH+emLsg= +github.com/livekit/protocol v1.11.0/go.mod h1:XpJ2t2wFnnQghPpkxXAzMZhYMDnm8wWxdxYJK4fP9gM= +github.com/livekit/psrpc v0.5.3-0.20240228172457-3724cb4adbc4 h1:253WtQ2VGVHzIIzW9MUZj7vUDDILESU3zsEbiRdxYF0= +github.com/livekit/psrpc v0.5.3-0.20240228172457-3724cb4adbc4/go.mod h1:CQUBSPfYYAaevg1TNCc6/aYsa8DJH4jSRFdCeSZk5u0= github.com/mackerelio/go-osstat v0.2.4 h1:qxGbdPkFo65PXOb/F/nhDKpF2nGmGaCFDLXoZjJTtUs= github.com/mackerelio/go-osstat v0.2.4/go.mod h1:Zy+qzGdZs3A9cuIqmgbJvwbmLQH9dJvtio5ZjJTbdlQ= github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= diff --git a/integration_test.go b/integration_test.go index dcaa912c..01e256fc 100644 --- a/integration_test.go +++ b/integration_test.go @@ -31,6 +31,7 @@ import ( "go.uber.org/atomic" "github.com/livekit/protocol/livekit" + "github.com/livekit/server-sdk-go/v2/pkg/interceptor" ) @@ -93,8 +94,10 @@ func TestJoin(t *testing.T) { pub, err := createAgent(t.Name(), nil, "publisher") require.NoError(t, err) - var dataLock sync.Mutex - var receivedData string + var ( + dataLock sync.Mutex + receivedData string + ) audioTrackName := "audio_of_pub1" var trackLock sync.Mutex @@ -102,10 +105,13 @@ func TestJoin(t *testing.T) { subCB := &RoomCallback{ ParticipantCallback: ParticipantCallback{ - OnDataReceived: func(data []byte, params DataReceiveParams) { - dataLock.Lock() - receivedData = string(data) - dataLock.Unlock() + OnDataPacket: func(data DataPacket, params DataReceiveParams) { + switch data := data.(type) { + case *UserDataPacket: + dataLock.Lock() + receivedData = string(data.Payload) + dataLock.Unlock() + } }, OnTrackSubscribed: func(track *webrtc.TrackRemote, publication *RemoteTrackPublication, rp *RemoteParticipant) { trackLock.Lock() @@ -123,7 +129,8 @@ func TestJoin(t *testing.T) { require.NotNil(t, serverInfo) require.Equal(t, serverInfo.Edition, livekit.ServerInfo_Standard) - pub.LocalParticipant.PublishData([]byte("test"), WithDataPublishReliable(true)) + pub.LocalParticipant.PublishDataPacket(UserData([]byte("test"))) + pub.LocalParticipant.PublishDataPacket(&livekit.SipDTMF{Digit: "#"}) localPub := pubNullTrack(t, pub, audioTrackName) require.Equal(t, localPub.Name(), audioTrackName) @@ -203,7 +210,7 @@ func TestForceTLS(t *testing.T) { require.NoError(t, err) // ensure publisher connected - pub.LocalParticipant.PublishData([]byte("test"), WithDataPublishReliable(true)) + pub.LocalParticipant.PublishDataPacket(UserData([]byte("test"))) pub.Simulate(SimulateForceTLS) require.Eventually(t, func() bool { return reconnected.Load() && pub.engine.ensurePublisherConnected(true) == nil }, 15*time.Second, 100*time.Millisecond) @@ -249,7 +256,7 @@ func TestSubscribeMutedTrack(t *testing.T) { var trackReceived atomic.Int32 var pubTrackMuted sync.WaitGroup - require.NoError(t, pub.LocalParticipant.PublishData([]byte("test"), WithDataPublishReliable(true))) + require.NoError(t, pub.LocalParticipant.PublishDataPacket(UserData([]byte("test")))) pubMuteTrack := func(t *testing.T, room *Room, name string, codec webrtc.RTPCodecCapability) *LocalTrackPublication { pubTrackMuted.Add(1) diff --git a/localparticipant.go b/localparticipant.go index 730c8875..699fe8ef 100644 --- a/localparticipant.go +++ b/localparticipant.go @@ -269,42 +269,107 @@ func (p *LocalParticipant) closeTracks() { } } -func (p *LocalParticipant) PublishData( - payload []byte, - opts ...DataPublishOption, -) error { +func (p *LocalParticipant) publishData(kind livekit.DataPacket_Kind, dataPacket *livekit.DataPacket) error { + if err := p.engine.ensurePublisherConnected(true); err != nil { + return err + } + + encoded, err := proto.Marshal(dataPacket) + if err != nil { + return err + } + + return p.engine.GetDataChannel(kind).Send(encoded) +} + +// PublishData sends custom user data via WebRTC data channel. +// +// By default, the message can be received by all participants in a room, +// see WithDataPublishDestination for choosing specific participants. +// +// Messages are sent via a LOSSY channel by default, see WithDataPublishReliable for sending reliable data. +// +// Deprecated: Use PublishDataPacket with UserData instead. Note that it sends reliable packets by default. +func (p *LocalParticipant) PublishData(payload []byte, opts ...DataPublishOption) error { options := &dataPublishOptions{} for _, opt := range opts { opt(options) } - packet := &livekit.UserPacket{ - Payload: payload, - DestinationIdentities: options.DestinationIdentities, + if options.Reliable == nil { + // Old logic sends packets as lossy by default. + opts = append(opts, WithDataPublishReliable(false)) } - if options.Topic != "" { - packet.Topic = proto.String(options.Topic) + return p.PublishDataPacket(UserData(payload), opts...) +} + +type DataPacket interface { + ToProto() *livekit.DataPacket +} + +// Compile-time assertion for all supported data packet types. +var ( + _ DataPacket = (*UserDataPacket)(nil) + _ DataPacket = (*livekit.SipDTMF)(nil) // implemented in the protocol package +) + +// UserData is a custom user data that can be sent via WebRTC. +func UserData(data []byte) *UserDataPacket { + return &UserDataPacket{Payload: data} +} + +// UserDataPacket is a custom user data that can be sent via WebRTC on a custom topic. +type UserDataPacket struct { + Payload []byte + Topic string // optional +} + +// ToProto implements DataPacket. +func (p *UserDataPacket) ToProto() *livekit.DataPacket { + var topic *string + if p.Topic != "" { + topic = proto.String(p.Topic) } - dataPacket := &livekit.DataPacket{ - Value: &livekit.DataPacket_User{ - User: packet, + return &livekit.DataPacket{Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: p.Payload, + Topic: topic, }, + }} +} + +// PublishDataPacket sends a packet via a WebRTC data channel. UserData can be used for sending custom user data. +// +// By default, the message can be received by all participants in a room, +// see WithDataPublishDestination for choosing specific participants. +// +// Messages are sent via a RELIABLE channel, see WithDataPublishReliable for sending lossy data. +func (p *LocalParticipant) PublishDataPacket(pck DataPacket, opts ...DataPublishOption) error { + options := &dataPublishOptions{} + for _, opt := range opts { + opt(options) } - if options.Reliable { - dataPacket.Kind = livekit.DataPacket_RELIABLE - } else { - dataPacket.Kind = livekit.DataPacket_LOSSY + dataPacket := pck.ToProto() + if options.Topic != "" { + if u, ok := dataPacket.Value.(*livekit.DataPacket_User); ok && u.User != nil { + u.User.Topic = proto.String(options.Topic) + } } - - if err := p.engine.ensurePublisherConnected(true); err != nil { - return err + // New logic sends packets as reliable by default. + // This matches the default value of Kind on protobuf level. + kind := livekit.DataPacket_RELIABLE + if options.Reliable != nil && !*options.Reliable { + kind = livekit.DataPacket_LOSSY } + //lint:ignore SA1019 backward compatibility + dataPacket.Kind = kind - encoded, err := proto.Marshal(dataPacket) - if err != nil { - return err + dataPacket.DestinationIdentities = options.DestinationIdentities + if u, ok := dataPacket.Value.(*livekit.DataPacket_User); ok && u.User != nil { + //lint:ignore SA1019 backward compatibility + u.User.DestinationIdentities = options.DestinationIdentities } - return p.engine.GetDataChannel(dataPacket.Kind).Send(encoded) + return p.publishData(kind, dataPacket) } func (p *LocalParticipant) UnpublishTrack(sid string) error { diff --git a/room.go b/room.go index cedb62fe..84f302be 100644 --- a/room.go +++ b/room.go @@ -142,7 +142,7 @@ func NewRoom(callback *RoomCallback) *Room { engine.OnDisconnected = r.handleDisconnect engine.OnParticipantUpdate = r.handleParticipantUpdate engine.OnSpeakersChanged = r.handleSpeakersChange - engine.OnDataReceived = r.handleDataReceived + engine.OnDataPacket = r.handleDataReceived engine.OnConnectionQuality = r.handleConnectionQualityUpdate engine.OnRoomUpdate = r.handleRoomUpdate engine.OnRestarting = r.handleRestarting @@ -377,21 +377,28 @@ func (r *Room) handleResumed() { r.sendSyncState() } -func (r *Room) handleDataReceived(userPacket *livekit.UserPacket) { - if userPacket.ParticipantIdentity == r.LocalParticipant.Identity() { +func (r *Room) handleDataReceived(identity string, dataPacket DataPacket) { + if identity == r.LocalParticipant.Identity() { // if sent by itself, do not handle data return } - p := r.GetParticipantByIdentity(userPacket.ParticipantIdentity) + p := r.GetParticipantByIdentity(identity) params := DataReceiveParams{ - Topic: userPacket.GetTopic(), - SenderIdentity: userPacket.ParticipantIdentity, + SenderIdentity: identity, + Sender: p, + } + switch msg := dataPacket.(type) { + case *UserDataPacket: // compatibility + params.Topic = msg.Topic + if p != nil { + p.Callback.OnDataReceived(msg.Payload, params) + } + r.callback.OnDataReceived(msg.Payload, params) } if p != nil { - params.Sender = p - p.Callback.OnDataReceived(userPacket.Payload, params) + p.Callback.OnDataPacket(dataPacket, params) } - r.callback.OnDataReceived(userPacket.Payload, params) + r.callback.OnDataPacket(dataPacket, params) } func (r *Room) handleParticipantUpdate(participants []*livekit.ParticipantInfo) {