Skip to content

Commit

Permalink
conn.go: Include selected candidate pair to Addr
Browse files Browse the repository at this point in the history
  • Loading branch information
lactyy committed Sep 23, 2024
1 parent 9b4af1d commit 1b864df
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 74 deletions.
111 changes: 57 additions & 54 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ func (*Conn) SetWriteDeadline(time.Time) error {
func (c *Conn) LocalAddr() net.Addr {
addr := c.local
addr.ConnectionID = c.id
pair, _ := c.ice.GetSelectedCandidatePair()
if pair != nil {
addr.SelectedCandidate = pair.Local
}
return &addr
}

Expand All @@ -154,6 +158,11 @@ func (c *Conn) RemoteAddr() net.Addr {
c.candidatesMu.Lock()
addr.Candidates = slices.Clone(c.candidates)
c.candidatesMu.Unlock()

pair, _ := c.ice.GetSelectedCandidatePair()
if pair != nil {
addr.SelectedCandidate = pair.Remote
}
return addr
}

Expand All @@ -176,9 +185,15 @@ func (c *Conn) Close() (err error) {

c.negotiator.handleClose(c)

if c.reliable != nil {
err = c.reliable.Close()
}
if c.unreliable != nil {
err = errors.Join(err, c.unreliable.Close())
}

err = errors.Join(
c.reliable.Close(),
c.unreliable.Close(),
err,
c.sctp.Stop(),
c.dtls.Stop(),
c.ice.Stop(),
Expand Down Expand Up @@ -254,7 +269,6 @@ func (c *Conn) handleSignal(signal *Signal) error {
Typ: webrtc.ICECandidateType(candidate.Type()),
TCPType: candidate.TCPType().String(),
}

if r := candidate.RelatedAddress(); r != nil {
i.RelatedAddress, i.RelatedPort = r.Address, uint16(r.Port)
}
Expand Down Expand Up @@ -323,9 +337,9 @@ func parseDescription(d *sdp.SessionDescription) (*description, error) {
}
var role webrtc.DTLSRole
switch attr {
case "active":
case sdp.ConnectionRoleActive.String():
role = webrtc.DTLSRoleClient
case "actpass":
case sdp.ConnectionRoleActpass.String():
role = webrtc.DTLSRoleServer
default:
return nil, fmt.Errorf("invalid setup attribute: %s", attr)
Expand Down Expand Up @@ -360,14 +374,6 @@ func parseDescription(d *sdp.SessionDescription) (*description, error) {
}, nil
}

// description contains parameters for calling the Start method of ICE, DTLS and SCTP transport.
//
// A description may be parsed by a negotiator (Listener or Dialer) using parseDescription
// with a [sdp.SessionDescription] decoded from a Signal of SignalTypeOffer or SignalTypeAnswer.
//
// A description may be filled in by a negotiator (Listener or Dialer) to encode
// a local description of a Conn.

// description contains parameters necessary for starting ICE, DTLS, and SCTP transport within a Conn.
//
// It may be created by parsing a [sdp.SessionDescription] signaled from a remote connection or filed
Expand All @@ -386,7 +392,7 @@ type description struct {
// parameters of each transport within a Conn.
func (desc description) encode() ([]byte, error) {
d := &sdp.SessionDescription{
Version: 0x2,
Version: 0x0,
Origin: sdp.Origin{
Username: "-",
SessionID: rand.Uint64(),
Expand All @@ -400,53 +406,50 @@ func (desc description) encode() ([]byte, error) {
{},
},
Attributes: []sdp.Attribute{
{Key: "group", Value: "BUNDLE 0"},
{Key: "extmap-allow-mixed", Value: ""},
{Key: "msid-semantic", Value: " WMS"},
{Key: sdp.AttrKeyGroup, Value: "BUNDLE 0"},
sdp.NewPropertyAttribute(sdp.AttrKeyExtMapAllowMixed),
{Key: sdp.AttrKeyMsidSemantic, Value: " WMS"},
},
MediaDescriptions: []*sdp.MediaDescription{
{
MediaName: sdp.MediaName{
Media: "application",
Port: sdp.RangedPort{Value: 9},
Protos: []string{"UDP", "DTLS", "SCTP"},
Formats: []string{"webrtc-datachannel"},
},
ConnectionInformation: &sdp.ConnectionInformation{
NetworkType: "IN",
AddressType: "IP4",
Address: &sdp.Address{Address: "0.0.0.0"},
},
Attributes: []sdp.Attribute{
{Key: "ice-ufrag", Value: desc.ice.UsernameFragment},
{Key: "ice-pwd", Value: desc.ice.Password},
{Key: "ice-options", Value: "trickle"},
{Key: "fingerprint", Value: fmt.Sprintf("%s %s",
desc.dtls.Fingerprints[0].Algorithm,
desc.dtls.Fingerprints[0].Value,
)},
desc.setupAttribute(),
{Key: "mid", Value: "0"},
{Key: "sctp-port", Value: "5000"},
{Key: "max-message-size", Value: strconv.FormatUint(uint64(desc.sctp.MaxMessageSize), 10)},
},
}

media := &sdp.MediaDescription{
MediaName: sdp.MediaName{
Media: "application",
Port: sdp.RangedPort{Value: 9},
Protos: []string{"UDP", "DTLS", "SCTP"},
Formats: []string{"webrtc-datachannel"},
},
ConnectionInformation: &sdp.ConnectionInformation{
NetworkType: "IN",
AddressType: "IP4",
Address: &sdp.Address{
Address: "0.0.0.0",
},
},
}
return d.Marshal()
media.WithICECredentials(desc.ice.UsernameFragment, desc.ice.Password)
media.WithValueAttribute("ice-options", "trickle")
for _, fingerprint := range desc.dtls.Fingerprints {
media.WithFingerprint(fingerprint.Algorithm, fingerprint.Value)
}
media.WithValueAttribute(sdp.AttrKeyConnectionSetup, desc.connectionRole(desc.dtls.Role).String())
media.WithValueAttribute(sdp.AttrKeyMID, "0")
media.WithValueAttribute("sctp-port", "5000")
media.WithValueAttribute("max-message-size", strconv.FormatUint(uint64(desc.sctp.MaxMessageSize), 10))

return d.WithMedia(media).Marshal()
}

// setupAttribute returns a [sdp.Attribute] with the key 'setup' indicating the local
// DTLS role as either 'active' or 'actpass'. It is called by encode to include the role
// in the media description of the local [sdp.SessionDescription].
func (desc description) setupAttribute() sdp.Attribute {
attr := sdp.Attribute{Key: "setup"}
if desc.dtls.Role == webrtc.DTLSRoleServer {
attr.Value = "actpass"
} else {
attr.Value = "active"
// connectionRole returns a [sdp.ConnectionRole] indicating the local DTLS role. It is called
// by encode to include the role into the media description of local [sdp.SessionDescription]
// as a [sdp.Attribute] of 'setup'.
func (desc description) connectionRole(role webrtc.DTLSRole) sdp.ConnectionRole {
switch role {
case webrtc.DTLSRoleServer:
return sdp.ConnectionRoleActpass
default:
return sdp.ConnectionRoleActive
}
return attr
}

// newConn creates a Conn from the ICE, DTLS and SCTP transport associated with the IDs.
Expand Down
16 changes: 8 additions & 8 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ type Dialer struct {
// has been negotiated by Dialer.
Log *slog.Logger

// API specifies custom configuration for WebRTC transports, and data channels. If left as nil, a new [webrtc.API]
// will be set from [webrtc.NewAPI]. The webrtc.SettingEngine of the API should not allow detaching data channels
// (by calling [webrtc.SettingEngine.DetachDataChannels]) as it requires additional steps on the Conn.

// API specifies custom configuration for WebRTC transports and data channels. If nil, a new [webrtc.API] will be
// set from [webrtc.NewAPI]. The [webrtc.SettingEngine] of the API should not allow detaching data channels, as it requires
// additional steps on the Conn (which cannot be determined by the Conn).
Expand Down Expand Up @@ -248,17 +244,21 @@ func (d Dialer) startTransports(ctx context.Context, conn *Conn, desc *descripti
if err := withContext(ctx, func() error {
var err error
conn.reliable, err = d.API.NewDataChannel(conn.sctp, &webrtc.DataChannelParameters{
Label: "ReliableDataChannel",
Label: "ReliableDataChannel",
Ordered: true,
})
return err
}); err != nil {
return fmt.Errorf("create ReliableDataChannel: %w", err)
}
if err := withContext(ctx, func() error {
var err error
var (
err error
maxRetransmits uint16 = 0
)
conn.unreliable, err = d.API.NewDataChannel(conn.sctp, &webrtc.DataChannelParameters{
Label: "UnreliableDataChannel",
Ordered: false,
Label: "UnreliableDataChannel",
MaxRetransmits: &maxRetransmits,
})
return err
}); err != nil {
Expand Down
38 changes: 26 additions & 12 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ type Listener struct {

stop func()
closed chan struct{}
once sync.Once
}

// Accept waits for and returns the next [Conn] to the Listener. An error may be
Expand Down Expand Up @@ -113,6 +112,12 @@ type Addr struct {
// signaled from a remote connection. ICE candidates are used to determine the UDP/TCP addresses
// for establishing ICE transport and can be used to determine the network address of the connection.
Candidates []webrtc.ICECandidate

// SelectedCandidate is the candidate selected to connect with the ICE transport within a Conn.
// An ICE candidate may be used to determine the UDP/TCP address of the connection. It may be nil
// if the Conn has been closed, or if the Conn has encountered an error when obtaining the selected
// ICE candidate pair.
SelectedCandidate *webrtc.ICECandidate
}

// String formats the Addr as a string.
Expand All @@ -125,6 +130,12 @@ func (addr *Addr) String() string {
b.WriteString(strconv.FormatUint(addr.ConnectionID, 10))
b.WriteByte(')')
}
if addr.SelectedCandidate != nil {
b.WriteByte(' ')
b.WriteByte('(')
b.WriteString(addr.SelectedCandidate.String())
b.WriteByte(')')
}
return b.String()
}

Expand Down Expand Up @@ -332,7 +343,7 @@ func (l *Listener) handleConn(conn *Conn, d *description) {
var err error
defer func() {
if err != nil {
l.connections.Delete(conn.remoteAddr().String()) // Stop notifying for the Conn.
_ = conn.Close() // Stop notifying for the Conn.

if errors.Is(err, context.DeadlineExceeded) {
if err := l.signaling.Signal(&Signal{
Expand Down Expand Up @@ -368,6 +379,9 @@ func (l *Listener) handleConn(conn *Conn, d *description) {
select {
case <-ctx.Done():
err = ctx.Err()
case <-l.closed:
case <-conn.closed:
return
case <-conn.candidateReceived:
conn.log.Debug("received first candidate")
if err = l.startTransports(ctx, conn, d); err != nil {
Expand Down Expand Up @@ -398,21 +412,18 @@ func (l *Listener) startTransports(ctx context.Context, conn *Conn, d *descripti
}

conn.log.Debug("starting SCTP transport")
var (
once = new(sync.Once)
opened = make(chan struct{}, 1)
)
opened := make(chan struct{}, 1)
conn.sctp.OnDataChannelOpened(func(channel *webrtc.DataChannel) {
switch channel.Label() {
case "ReliableDataChannel":
conn.reliable = channel
case "UnreliableDataChannel":
conn.unreliable = channel
default:
return
}
if conn.reliable != nil && conn.unreliable != nil {
once.Do(func() {
close(opened)
})
close(opened)
}
})
if err := withContext(ctx, func() error {
Expand Down Expand Up @@ -449,12 +460,15 @@ func withContext(ctx context.Context, f func() error) error {

// Close closes the Listener, ensuring that any blocking methods will return [net.ErrClosed] as an error.
func (l *Listener) Close() error {
l.once.Do(func() {
select {
case <-l.closed:
return nil
default:
close(l.closed)
close(l.incoming)
l.stop()
})
return nil
return nil
}
}

// A signalError may be returned by the methods of Listener to handle incoming Signals signaled from the
Expand Down

0 comments on commit 1b864df

Please sign in to comment.