-
Notifications
You must be signed in to change notification settings - Fork 165
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Steffen Vogel <[email protected]> Co-authored-by: Artur Shellunts <[email protected]>
- Loading branch information
1 parent
9f4e3d6
commit db5d7ea
Showing
8 changed files
with
547 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> | ||
// SPDX-License-Identifier: MIT | ||
|
||
package ice | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"net" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/pion/logging" | ||
"github.com/pion/transport/v2/packetio" | ||
) | ||
|
||
type activeTCPConn struct { | ||
readBuffer, writeBuffer *packetio.Buffer | ||
localAddr, remoteAddr atomic.Value | ||
closed int32 | ||
} | ||
|
||
func newActiveTCPConn(ctx context.Context, localAddress, remoteAddress string, log logging.LeveledLogger) (a *activeTCPConn) { | ||
a = &activeTCPConn{ | ||
readBuffer: packetio.NewBuffer(), | ||
writeBuffer: packetio.NewBuffer(), | ||
} | ||
|
||
laddr, err := getTCPAddrOnInterface(localAddress) | ||
if err != nil { | ||
atomic.StoreInt32(&a.closed, 1) | ||
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) | ||
return | ||
} | ||
a.localAddr.Store(laddr) | ||
|
||
go func() { | ||
defer func() { | ||
atomic.StoreInt32(&a.closed, 1) | ||
}() | ||
|
||
dialer := &net.Dialer{ | ||
LocalAddr: laddr, | ||
} | ||
conn, err := dialer.DialContext(ctx, "tcp", remoteAddress) | ||
if err != nil { | ||
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) | ||
return | ||
} | ||
|
||
a.remoteAddr.Store(conn.RemoteAddr()) | ||
|
||
go func() { | ||
buff := make([]byte, receiveMTU) | ||
|
||
for atomic.LoadInt32(&a.closed) == 0 { | ||
n, err := readStreamingPacket(conn, buff) | ||
if err != nil { | ||
log.Infof("%v: %s", errReadingStreamingPacket, err) | ||
break | ||
} | ||
|
||
if _, err := a.readBuffer.Write(buff[:n]); err != nil { | ||
log.Infof("%v: %s", errReadingStreamingPacket, err) | ||
break | ||
} | ||
} | ||
}() | ||
|
||
buff := make([]byte, receiveMTU) | ||
|
||
for atomic.LoadInt32(&a.closed) == 0 { | ||
n, err := a.writeBuffer.Read(buff) | ||
if err != nil { | ||
log.Infof("%v: %s", errReadingStreamingPacket, err) | ||
break | ||
} | ||
|
||
if _, err = writeStreamingPacket(conn, buff[:n]); err != nil { | ||
log.Infof("%v: %s", errReadingStreamingPacket, err) | ||
break | ||
} | ||
} | ||
|
||
if err := conn.Close(); err != nil { | ||
log.Infof("%v: %s", errReadingStreamingPacket, err) | ||
} | ||
}() | ||
|
||
return a | ||
} | ||
|
||
func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err error) { | ||
if atomic.LoadInt32(&a.closed) == 1 { | ||
return 0, nil, io.ErrClosedPipe | ||
} | ||
|
||
srcAddr = a.RemoteAddr() | ||
n, err = a.readBuffer.Read(buff) | ||
return | ||
} | ||
|
||
func (a *activeTCPConn) WriteTo(buff []byte, _ net.Addr) (n int, err error) { | ||
if atomic.LoadInt32(&a.closed) == 1 { | ||
return 0, io.ErrClosedPipe | ||
} | ||
|
||
return a.writeBuffer.Write(buff) | ||
} | ||
|
||
func (a *activeTCPConn) Close() error { | ||
atomic.StoreInt32(&a.closed, 1) | ||
_ = a.readBuffer.Close() | ||
_ = a.writeBuffer.Close() | ||
return nil | ||
} | ||
|
||
func (a *activeTCPConn) LocalAddr() net.Addr { | ||
if v, ok := a.localAddr.Load().(*net.TCPAddr); ok { | ||
return v | ||
} | ||
|
||
return &net.TCPAddr{} | ||
} | ||
|
||
func (a *activeTCPConn) RemoteAddr() net.Addr { | ||
if v, ok := a.remoteAddr.Load().(*net.TCPAddr); ok { | ||
return v | ||
} | ||
|
||
return &net.TCPAddr{} | ||
} | ||
|
||
func (a *activeTCPConn) SetDeadline(time.Time) error { return io.EOF } | ||
func (a *activeTCPConn) SetReadDeadline(time.Time) error { return io.EOF } | ||
func (a *activeTCPConn) SetWriteDeadline(time.Time) error { return io.EOF } | ||
|
||
func getTCPAddrOnInterface(address string) (*net.TCPAddr, error) { | ||
addr, err := net.ResolveTCPAddr("tcp", address) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
l, err := net.ListenTCP("tcp", addr) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer func() { | ||
_ = l.Close() | ||
}() | ||
|
||
tcpAddr, ok := l.Addr().(*net.TCPAddr) | ||
if !ok { | ||
return nil, errInvalidAddress | ||
} | ||
|
||
return tcpAddr, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> | ||
// SPDX-License-Identifier: MIT | ||
|
||
//go:build !js | ||
// +build !js | ||
|
||
package ice | ||
|
||
import ( | ||
"net" | ||
"testing" | ||
"time" | ||
|
||
"github.com/pion/logging" | ||
"github.com/pion/transport/v2/stdnet" | ||
"github.com/pion/transport/v2/test" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func getLocalIPAddress(t *testing.T, networkType NetworkType) net.IP { | ||
net, err := stdnet.NewNet() | ||
require.NoError(t, err) | ||
localIPs, err := localInterfaces(net, nil, nil, []NetworkType{networkType}, false) | ||
require.NoError(t, err) | ||
require.NotEmpty(t, localIPs) | ||
return localIPs[0] | ||
} | ||
|
||
func ipv6Available(t *testing.T) bool { | ||
net, err := stdnet.NewNet() | ||
require.NoError(t, err) | ||
localIPs, err := localInterfaces(net, nil, nil, []NetworkType{NetworkTypeTCP6}, false) | ||
require.NoError(t, err) | ||
return len(localIPs) > 0 | ||
} | ||
|
||
func TestActiveTCP(t *testing.T) { | ||
report := test.CheckRoutines(t) | ||
defer report() | ||
|
||
lim := test.TimeOut(time.Second * 5) | ||
defer lim.Stop() | ||
|
||
const listenPort = 7686 | ||
type testCase struct { | ||
name string | ||
networkTypes []NetworkType | ||
listenIPAddress net.IP | ||
selectedPairNetworkType string | ||
} | ||
|
||
testCases := []testCase{ | ||
{ | ||
name: "TCP4 connection", | ||
networkTypes: []NetworkType{NetworkTypeTCP4}, | ||
listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP4), | ||
selectedPairNetworkType: tcp, | ||
}, | ||
{ | ||
name: "UDP is preferred over TCP4", // This fails some time | ||
networkTypes: supportedNetworkTypes(), | ||
listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP4), | ||
selectedPairNetworkType: udp, | ||
}, | ||
} | ||
|
||
if ipv6Available(t) { | ||
testCases = append(testCases, | ||
testCase{ | ||
name: "TCP6 connection", | ||
networkTypes: []NetworkType{NetworkTypeTCP6}, | ||
listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), | ||
selectedPairNetworkType: tcp, | ||
}, | ||
testCase{ | ||
name: "UDP is preferred over TCP6", // This fails some time | ||
networkTypes: supportedNetworkTypes(), | ||
listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), | ||
selectedPairNetworkType: udp, | ||
}, | ||
) | ||
} | ||
|
||
for _, testCase := range testCases { | ||
t.Run(testCase.name, func(t *testing.T) { | ||
r := require.New(t) | ||
|
||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{ | ||
IP: testCase.listenIPAddress, | ||
Port: listenPort, | ||
}) | ||
r.NoError(err) | ||
defer func() { | ||
_ = listener.Close() | ||
}() | ||
|
||
loggerFactory := logging.NewDefaultLoggerFactory() | ||
|
||
tcpMux := NewTCPMuxDefault(TCPMuxParams{ | ||
Listener: listener, | ||
Logger: loggerFactory.NewLogger("passive-ice-tcp-mux"), | ||
ReadBufferSize: 20, | ||
}) | ||
|
||
defer func() { | ||
_ = tcpMux.Close() | ||
}() | ||
|
||
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") | ||
|
||
hostAcceptanceMinWait := 100 * time.Millisecond | ||
passiveAgent, err := NewAgent(&AgentConfig{ | ||
TCPMux: tcpMux, | ||
CandidateTypes: []CandidateType{CandidateTypeHost}, | ||
NetworkTypes: testCase.networkTypes, | ||
LoggerFactory: loggerFactory, | ||
IncludeLoopback: true, | ||
HostAcceptanceMinWait: &hostAcceptanceMinWait, | ||
}) | ||
r.NoError(err) | ||
r.NotNil(passiveAgent) | ||
|
||
activeAgent, err := NewAgent(&AgentConfig{ | ||
CandidateTypes: []CandidateType{CandidateTypeHost}, | ||
NetworkTypes: testCase.networkTypes, | ||
LoggerFactory: loggerFactory, | ||
HostAcceptanceMinWait: &hostAcceptanceMinWait, | ||
}) | ||
r.NoError(err) | ||
r.NotNil(activeAgent) | ||
|
||
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent) | ||
r.NotNil(passiveAgentConn) | ||
r.NotNil(activeAgenConn) | ||
|
||
pair := passiveAgent.getSelectedPair() | ||
r.NotNil(pair) | ||
r.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort()) | ||
|
||
foo := []byte("foo") | ||
_, err = passiveAgentConn.Write(foo) | ||
r.NoError(err) | ||
|
||
buffer := make([]byte, 1024) | ||
n, err := activeAgenConn.Read(buffer) | ||
r.NoError(err) | ||
r.Equal(foo, buffer[:n]) | ||
|
||
bar := []byte("bar") | ||
_, err = activeAgenConn.Write(bar) | ||
r.NoError(err) | ||
|
||
n, err = passiveAgentConn.Read(buffer) | ||
r.NoError(err) | ||
r.Equal(bar, buffer[:n]) | ||
|
||
r.NoError(activeAgenConn.Close()) | ||
r.NoError(passiveAgentConn.Close()) | ||
}) | ||
} | ||
} | ||
|
||
// Assert that Active TCP connectivity isn't established inside | ||
// the main thread of the Agent | ||
func TestActiveTCP_NonBlocking(t *testing.T) { | ||
report := test.CheckRoutines(t) | ||
defer report() | ||
|
||
lim := test.TimeOut(time.Second * 5) | ||
defer lim.Stop() | ||
|
||
cfg := &AgentConfig{ | ||
NetworkTypes: supportedNetworkTypes(), | ||
} | ||
|
||
aAgent, err := NewAgent(cfg) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
||
bAgent, err := NewAgent(cfg) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
||
isConnected := make(chan interface{}) | ||
err = aAgent.OnConnectionStateChange(func(c ConnectionState) { | ||
if c == ConnectionStateConnected { | ||
close(isConnected) | ||
} | ||
}) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
||
// Add a invalid ice-tcp candidate to each | ||
invalidCandidate, err := UnmarshalCandidate("1052353102 1 tcp 1675624447 192.0.2.1 8080 typ host tcptype passive") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
assert.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate)) | ||
assert.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate)) | ||
|
||
connect(aAgent, bAgent) | ||
|
||
<-isConnected | ||
assert.NoError(t, aAgent.Close()) | ||
assert.NoError(t, bAgent.Close()) | ||
} |
Oops, something went wrong.