forked from df-mc/go-nethernet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
listener.go
523 lines (465 loc) · 17.9 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
package nethernet
import (
"context"
"errors"
"fmt"
"github.com/df-mc/go-nethernet/internal"
"github.com/pion/sdp/v3"
"github.com/pion/webrtc/v4"
"log/slog"
"net"
"strconv"
"strings"
"sync"
"time"
)
// ListenConfig encapsulates options for creating a new Listener through [ListenConfig.Listen].
// It allows customizing logging, WebRTC API settings, and contexts for negotiations.
type ListenConfig struct {
	// Log is used for logging messages at various levels. If nil, the default [slog.Logger] will be set from
	// [slog.Default]. Log will be extended when a Conn is being accepted by [Listener.Accept] with additional
	// attributes such as the connection ID and network ID, and will have a 'src' attribute set to 'listener'
	// to mark that the Conn has been negotiated by Listener.
	Log *slog.Logger
	// API specifies custom configuration for WebRTC transports and data channels. If nil, a new [webrtc.API] will
	// be set from [webrtc.NewAPI]. The [webrtc.SettingEngine] of the API should not allow detaching data channels,
	// as it requires additional steps on the Conn (which cannot be determined by the Conn).
	API *webrtc.API
	// ConnContext provides a [context.Context] used while waiting for the first ICE candidate from the remote
	// connection and while starting the ICE, DTLS, and SCTP transports of the Conn. If nil, a default
	// [context.Context] with 5 seconds timeout will be used. The parent [context.Context] is done once the
	// Listener is closed, and may be used to create the [context.Context] to be returned (likely using
	// [context.WithCancel] or [context.WithTimeout]). It must not return nil, or the Listener panics.
	// If the deadline of the returned context is exceeded, a Signal of SignalTypeError with
	// ErrorCodeNegotiationTimeoutWaitingForAccept will be signaled back to the remote connection.
	ConnContext func(parent context.Context, conn *Conn) context.Context
	// NegotiationContext provides a [context.Context] used while handling an incoming offer: it is passed to
	// the Signaling implementation to obtain credentials and bounds the gathering of local ICE candidates.
	// If nil, a default [context.Context] with 15 seconds timeout will be used. The parent [context.Context]
	// is done once the Listener is closed, and may be used to create the [context.Context] to be returned
	// (likely using [context.WithCancel] or [context.WithTimeout]). It must not return nil, or the Listener
	// panics. If the context is done before local ICE candidates have been gathered, a Signal of
	// SignalTypeError with ErrorCodeFailedToCreatePeerConnection will be signaled back.
	NegotiationContext func(parent context.Context) context.Context
}
// Listen starts listening on the local network ID reported by signaling. The returned Listener
// hands out negotiated connections through [Listener.Accept]. Incoming Signals from remote
// connections are received by registering a notifier on signaling; the unregister function it
// returns is kept so that [Listener.Close] can stop notifications. The returned error is always nil.
func (conf ListenConfig) Listen(signaling Signaling) (*Listener, error) {
	// conf is a private copy, so defaults can be filled in without affecting the caller.
	if conf.Log == nil {
		conf.Log = slog.Default()
	}
	if conf.API == nil {
		conf.API = webrtc.NewAPI()
	}

	listener := new(Listener)
	listener.conf = conf
	listener.signaling = signaling
	listener.networkID = signaling.NetworkID()
	listener.incoming = make(chan *Conn)
	listener.closed = make(chan struct{})

	// Register only once every field is initialised; presumably notifications may start arriving
	// as soon as Notify returns (TODO confirm against Signaling implementations).
	listener.stop = signaling.Notify(listenerNotifier{Listener: listener})
	return listener, nil
}
// Listener implements a NetherNet connection listener.
type Listener struct {
	// conf is the ListenConfig with the Log and API defaults filled in by ListenConfig.Listen.
	conf ListenConfig
	// signaling is used to send Signals to remote connections and to obtain credentials.
	signaling Signaling
	// networkID is the local network ID, obtained from Signaling.NetworkID.
	networkID uint64
	// connections holds the *Conn values being negotiated, keyed by the String form of an Addr
	// made of the connection ID and the remote network ID.
	connections sync.Map
	// incoming passes negotiated connections from handleConn to Accept.
	incoming chan *Conn
	// stop unregisters the notifier from signaling; it is called by Close.
	stop func()
	// closed is closed by Close; receiving from it unblocks once the Listener has been closed.
	closed chan struct{}
}
// Accept waits for and returns the next [Conn] to the Listener. It returns [net.ErrClosed] once the
// Listener has been closed.
func (l *Listener) Accept() (net.Conn, error) {
	select {
	case <-l.closed:
		return nil, net.ErrClosed
	case conn, ok := <-l.incoming:
		if !ok || conn == nil {
			// Never return a nil *Conn wrapped in a non-nil net.Conn together with a nil error;
			// a closed incoming channel means the Listener is gone.
			return nil, net.ErrClosed
		}
		return conn, nil
	}
}
// Addr reports the local network ID of the Listener as a [net.Addr]. The returned *Addr carries
// only the network ID; the connection ID and candidates are left empty.
func (l *Listener) Addr() net.Addr {
	local := &Addr{}
	local.NetworkID = l.networkID
	return local
}
// Addr represents a network address that encapsulates both local and remote connection
// IDs and implements [net.Addr].
//
// The Addr provides details for the unique IDs of Conn and ICE Candidates used for establishing
// network connectivity.
type Addr struct {
	// ConnectionID is a unique ID assigned to a connection. It is generated by the client and
	// used in Signals signaled between clients and servers to uniquely reference a specific connection.
	ConnectionID uint64
	// NetworkID is a unique ID for the NetherNet network.
	NetworkID uint64
	// Candidates contains a list of ICE candidates. These candidates are either gathered locally or
	// signaled from a remote connection. ICE candidates are used to determine the UDP/TCP addresses
	// for establishing ICE transport and can be used to determine the network address of the connection.
	Candidates []webrtc.ICECandidate
	// SelectedCandidate is the candidate selected to connect with the ICE transport within a Conn.
	// An ICE candidate may be used to determine the UDP/TCP address of the connection. It may be nil
	// if the Conn has been closed, or if the Conn has encountered an error when obtaining the selected
	// ICE candidate pair.
	SelectedCandidate *webrtc.ICECandidate
}

// String formats the Addr as a string: the network ID, then the connection ID in parentheses if it
// is non-zero, then the selected ICE candidate in parentheses if present. Parts are separated by a
// single space, with no leading or trailing whitespace (e.g. "123", "123 (456)", "123 (456) (cand)").
func (addr *Addr) String() string {
	b := &strings.Builder{}
	b.WriteString(strconv.FormatUint(addr.NetworkID, 10))
	if addr.ConnectionID != 0 {
		b.WriteString(" (")
		b.WriteString(strconv.FormatUint(addr.ConnectionID, 10))
		b.WriteByte(')')
	}
	if addr.SelectedCandidate != nil {
		b.WriteString(" (")
		b.WriteString(addr.SelectedCandidate.String())
		b.WriteByte(')')
	}
	return b.String()
}

// Network returns the network type for the Addr, which is always 'nethernet'.
func (addr *Addr) Network() string { return "nethernet" }
// ID returns the local network ID of the Listener, converted to an int64.
func (l *Listener) ID() int64 {
	networkID := l.networkID
	return int64(networkID)
}

// PongData is a stub: the data passed to it is ignored.
func (l *Listener) PongData(_ []byte) {}
// listenerNotifier is the notifier that [ListenConfig.Listen] registers on a Signaling
// implementation for a new Listener. It embeds *Listener, so the Listener's fields and methods
// are promoted and available to the notification handlers.
type listenerNotifier struct {
	*Listener
}
// NotifySignal delivers an incoming Signal to the Listener. Offers are handled by handleOffer;
// every other Signal type is routed to the matching Conn by handleSignal. When handling fails
// with an error that wraps a *signalError, a Signal of SignalTypeError carrying its code is sent
// back to the remote connection. Every handling error is logged.
func (l listenerNotifier) NotifySignal(signal *Signal) {
	handle := l.handleSignal
	if signal.Type == SignalTypeOffer {
		handle = l.handleOffer
	}

	err := handle(signal)
	if err == nil {
		return
	}

	var sigErr *signalError
	if errors.As(err, &sigErr) {
		reply := &Signal{
			Type:         SignalTypeError,
			ConnectionID: signal.ConnectionID,
			Data:         strconv.FormatUint(uint64(sigErr.code), 10),
			NetworkID:    signal.NetworkID,
		}
		if replyErr := l.signaling.Signal(reply); replyErr != nil {
			l.conf.Log.Error("error signaling error", internal.ErrAttr(replyErr))
		}
	}
	l.conf.Log.Error("error handling signal", slog.Any("signal", signal), internal.ErrAttr(err))
}
// NotifyError logs an error reported by the Signaling implementation. An error matching
// [ErrSignalingStopped] additionally closes the Listener.
func (l listenerNotifier) NotifyError(err error) {
	l.conf.Log.Error("notified error in signaling", internal.ErrAttr(err))
	if !errors.Is(err, ErrSignalingStopped) {
		return
	}
	// Signaling has stopped, so the Listener can no longer receive offers.
	_ = l.Close()
}
// handleOffer handles an incoming Signal of SignalTypeOffer.
//
// It works in these steps:
//   - Parse the data of the Signal into an [sdp.SessionDescription], then into the remote
//     description later used to start each transport.
//   - Within the context from [ListenConfig.NegotiationContext] (a 15 seconds timeout by default),
//     obtain credentials from the Signaling implementation and gather local ICE candidates.
//   - Create the ICE, DTLS and SCTP transports.
//   - Register a new Conn under the connection ID and network ID of the Signal.
//   - Signal back an answer encoded from the local parameters of each transport, followed by every
//     gathered local candidate.
//
// The Conn is then finalised in the background by [Listener.handleConn]. It panics if
// NegotiationContext returns nil.
//
// Every error returned wraps a *signalError carrying the code to be signaled back. On failure, the
// ICE gatherer is closed directly, or, once a Conn owns the transports, the Conn is removed from the
// Listener and closed.
func (l *Listener) handleOffer(signal *Signal) (err error) {
	d := &sdp.SessionDescription{}
	if err := d.UnmarshalString(signal.Data); err != nil {
		return wrapSignalError(fmt.Errorf("decode offer: %w", err), ErrorCodeFailedToSetRemoteDescription)
	}
	desc, err := parseDescription(d)
	if err != nil {
		return wrapSignalError(fmt.Errorf("parse offer: %w", err), ErrorCodeFailedToSetRemoteDescription)
	}

	var (
		ctx    context.Context
		parent = listenerContext{closed: l.closed}
	)
	if l.conf.NegotiationContext != nil {
		if ctx = l.conf.NegotiationContext(parent); ctx == nil {
			panic("nethernet: Listener: NegotiationContext returned nil")
		}
	} else {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(parent, time.Second*15)
		defer cancel()
	}

	credentials, err := l.signaling.Credentials(ctx)
	if err != nil {
		return wrapSignalError(fmt.Errorf("obtain credentials: %w", err), ErrorCodeSignalingTurnAuthFailed)
	}
	gatherer, err := l.conf.API.NewICEGatherer(gatherOptions(credentials))
	if err != nil {
		return wrapSignalError(fmt.Errorf("create ICE gatherer: %w", err), ErrorCodeFailedToCreatePeerConnection)
	}

	// key is derived exactly as handleSignal derives its lookup key from incoming Signals, so that
	// Signals referencing this connection are routed to the Conn registered below.
	key := (&Addr{ConnectionID: signal.ConnectionID, NetworkID: signal.NetworkID}).String()
	// conn is set once a Conn has taken ownership of the transports.
	var conn *Conn
	defer func() {
		if err == nil {
			return
		}
		if conn == nil {
			// No Conn owns the gatherer yet, so release it directly instead of leaking it.
			_ = gatherer.Close()
			return
		}
		l.connections.Delete(key)
		_ = conn.Close()
	}()

	var (
		// Local candidates gathered by webrtc.ICEGatherer.
		candidates []webrtc.ICECandidate
		// Closed once gathering for local candidates has finished.
		gatherFinished = make(chan struct{})
	)
	gatherer.OnLocalCandidate(func(candidate *webrtc.ICECandidate) {
		if candidate == nil {
			// This handler treats a nil candidate as the end of gathering.
			close(gatherFinished)
			return
		}
		candidates = append(candidates, *candidate)
	})
	if err := gatherer.Gather(); err != nil {
		return wrapSignalError(fmt.Errorf("gather local candidates: %w", err), ErrorCodeFailedToCreatePeerConnection)
	}
	select {
	case <-ctx.Done():
		// Report why waiting was abandoned (deadline exceeded or Listener closed).
		return wrapSignalError(fmt.Errorf("gather local candidates: %w", ctx.Err()), ErrorCodeFailedToCreatePeerConnection)
	case <-gatherFinished:
		// The close of gatherFinished happens-before this receive, so candidates is safe to read.
	}

	ice := l.conf.API.NewICETransport(gatherer)
	dtls, err := l.conf.API.NewDTLSTransport(ice, nil)
	if err != nil {
		return wrapSignalError(fmt.Errorf("create DTLS transport: %w", err), ErrorCodeFailedToCreatePeerConnection)
	}
	sctp := l.conf.API.NewSCTPTransport(dtls)

	iceParams, err := ice.GetLocalParameters()
	if err != nil {
		return wrapSignalError(fmt.Errorf("obtain local ICE parameters: %w", err), ErrorCodeFailedToCreateAnswer)
	}
	dtlsParams, err := dtls.GetLocalParameters()
	if err != nil {
		return wrapSignalError(fmt.Errorf("obtain local DTLS parameters: %w", err), ErrorCodeFailedToCreateAnswer)
	}
	if len(dtlsParams.Fingerprints) == 0 {
		return wrapSignalError(errors.New("local DTLS parameters has no fingerprints"), ErrorCodeFailedToCreateAnswer)
	}
	// Encode an answer using the local parameters of each transport.
	answer, err := description{
		ice:  iceParams,
		dtls: dtlsParams,
		sctp: sctp.GetCapabilities(),
	}.encode()
	if err != nil {
		return wrapSignalError(fmt.Errorf("encode answer: %w", err), ErrorCodeFailedToCreateAnswer)
	}

	// Register the Conn before signaling the answer, so that Signals sent by the remote connection in
	// response (such as its candidates) can already be routed to it by handleSignal.
	conn = newConn(ice, dtls, sctp, signal.ConnectionID, signal.NetworkID, Addr{
		NetworkID:  l.networkID,
		Candidates: candidates,
	}, l)
	l.connections.Store(key, conn)

	if err := l.signaling.Signal(&Signal{
		Type:         SignalTypeAnswer,
		ConnectionID: signal.ConnectionID,
		Data:         string(answer),
		NetworkID:    signal.NetworkID,
	}); err != nil {
		// The error code is unlikely to reach the remote connection, but it is signaled just in case.
		return wrapSignalError(fmt.Errorf("signal answer: %w", err), ErrorCodeSignalingFailedToSend)
	}
	for i, candidate := range candidates {
		if err := l.signaling.Signal(&Signal{
			Type:         SignalTypeCandidate,
			ConnectionID: signal.ConnectionID,
			Data:         formatICECandidate(i, candidate, iceParams),
			NetworkID:    signal.NetworkID,
		}); err != nil {
			// The error code is unlikely to reach the remote connection, but it is signaled just in case.
			return wrapSignalError(fmt.Errorf("signal candidate: %w", err), ErrorCodeSignalingFailedToSend)
		}
	}

	go l.handleConn(conn, desc)
	return nil
}
// handleSignal routes a Signal of any type other than SignalTypeOffer to the Conn registered under
// the Signal's connection ID and network ID, and returns the result of that Conn's handleSignal.
// It returns an error when no such Conn is registered on the Listener.
func (l *Listener) handleSignal(signal *Signal) error {
	addr := &Addr{NetworkID: signal.NetworkID, ConnectionID: signal.ConnectionID}
	value, found := l.connections.Load(addr.String())
	if found {
		return value.(*Conn).handleSignal(signal)
	}
	return fmt.Errorf("no connection found for %s", addr)
}
// handleClose deletes the Conn from the Listener, since it is closed and can no longer be negotiated.
//
// The key is built from the connection ID and remote network ID of the Conn, the same way handleSignal
// builds its lookup key from incoming Signals. It is deliberately not taken from the Conn's remote Addr,
// whose String form also includes the selected ICE candidate when one is present and would then miss
// the stored entry.
func (l *Listener) handleClose(conn *Conn) {
	addr := &Addr{
		ConnectionID: conn.id,
		NetworkID:    conn.networkID,
	}
	l.connections.Delete(addr.String())
}
// log returns the [slog.Logger] from [ListenConfig.Log] extended with a "src" attribute of
// "listener". It is meant to be used as the logger of a Conn negotiated by the Listener.
func (l *Listener) log() *slog.Logger {
	source := slog.String("src", "listener")
	return l.conf.Log.With(source)
}
// handleConn finalises the Conn and then hands it to [Listener.Accept].
//
// It waits for the first ICE candidate signaled by the remote connection. It then starts the Conn's
// transports using the remote description and a [context.Context] returned from
// [ListenConfig.ConnContext] (a 5 seconds timeout by default). It panics if ConnContext returns nil.
//
// The Conn is closed if the context is done, the Listener is closed, or starting the transports
// fails. When the failure is a deadline being exceeded, a Signal of SignalTypeError with
// ErrorCodeNegotiationTimeoutWaitingForAccept is signaled back. Failures other than [net.ErrClosed]
// are logged.
func (l *Listener) handleConn(conn *Conn, d *description) {
	var err error
	defer func() {
		if err == nil {
			return
		}
		_ = conn.Close() // Stop notifying for the Conn.
		if errors.Is(err, context.DeadlineExceeded) {
			if err := l.signaling.Signal(&Signal{
				Type:         SignalTypeError,
				ConnectionID: conn.id,
				Data:         strconv.Itoa(ErrorCodeNegotiationTimeoutWaitingForAccept),
				NetworkID:    conn.networkID,
			}); err != nil {
				conn.log.Error("error signaling timeout", internal.ErrAttr(err))
			}
		}
		if !errors.Is(err, net.ErrClosed) {
			conn.log.Error("error starting transports", internal.ErrAttr(err))
		}
	}()

	var (
		ctx    context.Context
		parent = listenerContext{closed: l.closed}
	)
	if l.conf.ConnContext != nil {
		ctx = l.conf.ConnContext(parent, conn)
		if ctx == nil {
			panic("nethernet: ConnContext returned nil")
		}
	} else {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(parent, time.Second*5)
		defer cancel()
	}

	select {
	case <-ctx.Done():
		err = ctx.Err()
	case <-l.closed:
		// The Conn must still be closed; the deferred function does so for any non-nil err.
		err = net.ErrClosed
	case <-conn.closed:
		return
	case <-conn.candidateReceived:
		conn.log.Debug("received first candidate")
		if err = l.startTransports(ctx, conn, d); err != nil {
			return
		}
		conn.handleTransports()
		// Hand the Conn to Accept, but give up if the Listener is closed in the meantime: nobody
		// would receive from incoming any more and the send would block forever.
		select {
		case l.incoming <- conn:
		case <-l.closed:
			err = net.ErrClosed
		}
	}
}
// startTransports starts the transports of the Conn using the remote description: ICE as
// [webrtc.ICERoleControlled], then DTLS (logged as the server side), then SCTP. Each Start call is
// bounded by ctx through withContext.
//
// It then blocks until the remote connection has opened two data channels, labeled
// 'ReliableDataChannel' and 'UnreliableDataChannel'. It returns early with [net.ErrClosed] if the
// Listener is closed, or with ctx.Err() if ctx is done.
//
// Only the first data channel opened for each of the two labels is kept. Channels with other labels,
// and later channels reusing a label that was already seen, are ignored. This keeps conn.reliable and
// conn.unreliable from being replaced once the Conn has been handed out. The readiness channel is
// closed through a sync.Once, because closing it twice would panic.
func (l *Listener) startTransports(ctx context.Context, conn *Conn, d *description) error {
	conn.log.Debug("starting ICE transport as controlled")
	iceRole := webrtc.ICERoleControlled
	if err := withContext(ctx, func() error {
		return conn.ice.Start(nil, d.ice, &iceRole)
	}); err != nil {
		return fmt.Errorf("start ICE: %w", err)
	}

	conn.log.Debug("starting DTLS transport as server")
	if err := withContext(ctx, func() error {
		return conn.dtls.Start(d.dtls)
	}); err != nil {
		return fmt.Errorf("start DTLS: %w", err)
	}

	conn.log.Debug("starting SCTP transport")
	var (
		// Closed once both data channels have been opened.
		opened     = make(chan struct{})
		openedOnce sync.Once
	)
	conn.sctp.OnDataChannelOpened(func(channel *webrtc.DataChannel) {
		switch channel.Label() {
		case "ReliableDataChannel":
			if conn.reliable != nil {
				return // Keep the first reliable channel.
			}
			conn.reliable = channel
		case "UnreliableDataChannel":
			if conn.unreliable != nil {
				return // Keep the first unreliable channel.
			}
			conn.unreliable = channel
		default:
			return
		}
		if conn.reliable != nil && conn.unreliable != nil {
			openedOnce.Do(func() { close(opened) })
		}
	})
	if err := withContext(ctx, func() error {
		return conn.sctp.Start(d.sctp)
	}); err != nil {
		return fmt.Errorf("start SCTP: %w", err)
	}

	select {
	case <-l.closed:
		return net.ErrClosed
	case <-opened:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// withContext runs f in a new goroutine and waits for whichever comes first: f returning, or ctx
// being done. It lets callers bound functions that take no [context.Context] of their own, such as
// the Start method of each transport of the Conn, which may hang if the remote connection does
// nothing. When ctx wins, ctx.Err() is returned and f keeps running in the background. Its result
// goes into a buffered channel, so the goroutine can still finish and exit.
func withContext(ctx context.Context, f func() error) error {
	result := make(chan error, 1)
	go func() {
		result <- f()
	}()
	select {
	case err := <-result:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Close closes the Listener and stops receiving notifications from the Signaling implementation.
// Afterwards, blocking methods such as Accept return [net.ErrClosed] as an error. Closing a
// Listener that has already been closed is a no-op that returns nil.
//
// The incoming channel is deliberately left open. A send on a closed channel panics, and handleConn
// may still be sending on it. Closing l.closed is the only close signal; every receiver selects on it.
//
// NOTE(review): two concurrent calls to Close can both observe l.closed as open and close it twice.
// Guarding against that needs a sync.Once (or a mutex) field on Listener.
func (l *Listener) Close() error {
	select {
	case <-l.closed:
		return nil
	default:
	}
	close(l.closed)
	l.stop()
	return nil
}
// signalError is returned by the Listener methods that handle incoming Signals. Besides the
// underlying error, it carries the error code that listenerNotifier signals back to the remote
// connection in a Signal of SignalTypeError.
type signalError struct {
	// code is the error code to be signaled back; one of the constants defined below SignalTypeError.
	code int
	// underlying is the wrapped error, exposed through Unwrap.
	underlying error
}

// Error formats the underlying error together with the code to be signaled back.
func (e *signalError) Error() string {
	const format = "nethernet: %s [signaling with code %d]"
	return fmt.Sprintf(format, e.underlying, e.code)
}

// Unwrap exposes the underlying error to [errors.Unwrap], [errors.Is] and [errors.As].
func (e *signalError) Unwrap() error {
	return e.underlying
}

// wrapSignalError bundles err with the code to be signaled back to the remote connection. The
// result unwraps to err. It is typically used by the methods handling incoming Signals on the Listener.
func wrapSignalError(err error, code int) *signalError {
	wrapped := new(signalError)
	wrapped.code, wrapped.underlying = code, err
	return wrapped
}
// listenerContext is a [context.Context] whose lifetime is bound to a Listener. It is done once the
// closed channel is closed, and it has no deadline and carries no values.
type listenerContext struct{ closed <-chan struct{} }

// Deadline reports that no deadline is set, returning the zero [time.Time] and false.
func (listenerContext) Deadline() (deadline time.Time, ok bool) {
	return time.Time{}, false
}

// Done returns the closed channel of the Listener.
func (ctx listenerContext) Done() <-chan struct{} { return ctx.closed }

// Err returns [net.ErrClosed] once the Listener has been closed, and nil before that.
func (ctx listenerContext) Err() error {
	select {
	case <-ctx.closed:
		return net.ErrClosed
	default:
	}
	return nil
}

// Value always returns nil, since no values are associated with the context.
func (listenerContext) Value(key any) any { return nil }