Skip to content

Commit

Permalink
* limit max number of webrtc p2p connections
Browse files Browse the repository at this point in the history
* add webrtc thread pool support
  • Loading branch information
vyloy committed Nov 17, 2023
1 parent b5f5601 commit ee826f9
Show file tree
Hide file tree
Showing 24 changed files with 386 additions and 586 deletions.
5 changes: 5 additions & 0 deletions client/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ func (s *Server) Start() {
}

func (s *Server) exchangeWebRTCInfo(writer http.ResponseWriter, request *http.Request) {
defer func() {
if request.Body != nil {
request.Body.Close()
}
}()
value := request.Context().Value(rawConn)
if value == nil {
return
Expand Down
10 changes: 8 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func New(args []string, out io.Writer) (c *Client, err error) {
c.tunnelsCond = sync.NewCond(c.tunnelsRWMtx.RLocker())
c.apiServer = api.NewServer(l.With().Str("scope", "api").Logger())
c.apiServer.ReadTimeout = 30 * time.Second
c.webrtcThreadPool = webrtc.NewThreadPool(3)
return
}
func getDefaultConfig(args []string) (conf Config) {
Expand Down Expand Up @@ -475,6 +476,11 @@ func (c *Client) Start() (err error) {
} else if c.Config().RemoteIdleConnections > c.Config().RemoteConnections {
c.Config().RemoteIdleConnections = c.Config().RemoteConnections
}
if c.Config().WebRTCRemoteConnections < 1 {
c.Config().WebRTCRemoteConnections = 1
} else if c.Config().WebRTCRemoteConnections > 50 {
c.Config().WebRTCRemoteConnections = 50
}
c.idleManager = newIdleManager(c.Config().RemoteIdleConnections)

conf4Log := *c.Config()
Expand All @@ -491,8 +497,6 @@ func (c *Client) Start() (err error) {
// tcpforward
if c.Config().TCPForwardConnections < 1 {
c.Config().TCPForwardConnections = 1
} else if c.Config().TCPForwardConnections > 10 {
c.Config().TCPForwardConnections = 10
}
if c.Config().TCPForwardHostPrefix != "" {
c.tcpForwardListener, err = net.Listen("tcp", c.Config().TCPForwardAddr)
Expand Down Expand Up @@ -554,6 +558,7 @@ func (c *Client) Close() {
if c.tcpForwardListener != nil {
_ = c.tcpForwardListener.Close()
}
//c.webrtcThreadPool.Close()
}

// Shutdown stops the client gracefully.
Expand Down Expand Up @@ -587,6 +592,7 @@ func (c *Client) ShutdownWithoutClosingLogger() {
if c.tcpForwardListener != nil {
_ = c.tcpForwardListener.Close()
}
//c.webrtcThreadPool.Close()
}

func (c *Client) initConn(d dialer, connID uint) (result *conn, err error) {
Expand Down
4 changes: 3 additions & 1 deletion client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package client
import (
"encoding/json"
"fmt"
"gopkg.in/yaml.v3"
"net/url"
"strconv"
"strings"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/isrc-cas/gt/config"
"github.com/isrc-cas/gt/predef"
"github.com/rs/zerolog"
"gopkg.in/yaml.v3"
)

// Config is a client config.
Expand Down Expand Up @@ -66,6 +66,7 @@ type Options struct {
SentryDebug bool `yaml:"sentryDebug,omitempty" json:",omitempty" usage:"Sentry debug mode, the debug information is printed to help you understand what sentry is doing"`

WebRTCConnectionIdleTimeout config.Duration `yaml:"webrtcConnectionIdleTimeout,omitempty" usage:"The timeout of WebRTC connection. Supports values like '30s', '5m'"`
WebRTCRemoteConnections uint `yaml:"webrtcConnections" usage:"The max number of webrtc connections. Valid value is 1 to 50"`
WebRTCLogLevel string `yaml:"webrtcLogLevel,omitempty" json:",omitempty" usage:"WebRTC log level: verbose, info, warning, error"`
WebRTCMinPort uint16 `yaml:"webrtcMinPort,omitempty" json:",omitempty" usage:"The min port of WebRTC peer connection"`
WebRTCMaxPort uint16 `yaml:"webrtcMaxPort,omitempty" json:",omitempty" usage:"The max port of WebRTC peer connection"`
Expand Down Expand Up @@ -108,6 +109,7 @@ func defaultConfig() Config {
SentryRelease: predef.Version,

WebRTCConnectionIdleTimeout: config.Duration{Duration: 5 * time.Minute},
WebRTCRemoteConnections: 30,
WebRTCLogLevel: "warning",

TCPForwardConnections: 3,
Expand Down
37 changes: 27 additions & 10 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,31 @@ func (c *conn) processData(taskID uint32, r *bufio.LimitedReader) (readErr, writ
}

func (c *conn) processP2P(id uint32, r *bufio.LimitedReader) {
t := &peerTask{}
t, ok := c.newPeerTask(id)
if !ok {
return
}

c.client.apiServer.Listener.AcceptCh() <- t.apiConn
t.Logger.Info().Msg("peer task started")
_, err := r.WriteTo(t.apiConn.PipeWriter)
if err != nil {
t.Logger.Error().Err(err).Msg("processP2P WriteTo failed")
}
}

func (c *conn) newPeerTask(id uint32) (t *peerTask, ok bool) {
c.client.peersRWMtx.Lock()
defer c.client.peersRWMtx.Unlock()
l := uint(len(c.client.peers))
if l >= c.client.Config().WebRTCRemoteConnections {
respAndClose(id, c, [][]byte{
[]byte("HTTP/1.1 403 Forbidden\r\nConnection: Closed\r\n\r\n"),
})
return
}

t = &peerTask{}
t.id = id
t.tunnel = c
t.apiConn = api.NewConn(id, "", c)
Expand All @@ -513,19 +537,12 @@ func (c *conn) processP2P(id uint32, r *bufio.LimitedReader) {
t.CloseWithLock()
})

c.client.peersRWMtx.Lock()
ot, ok := c.client.peers[id]
if ok && ot != nil {
ot.CloseWithLock()
ot.Logger.Info().Msg("got closed because task with same id is received")
}
c.client.peers[id] = t
c.client.peersRWMtx.Unlock()

c.client.apiServer.Listener.AcceptCh() <- t.apiConn
t.Logger.Info().Msg("peer task started")
_, err := r.WriteTo(t.apiConn.PipeWriter)
if err != nil {
t.Logger.Error().Err(err).Msg("processP2P WriteTo failed")
}
ok = true
return
}
3 changes: 2 additions & 1 deletion client/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

//go:build !release
// +build !release

package client

Expand All @@ -23,6 +22,7 @@ import (
"sync/atomic"

"github.com/isrc-cas/gt/client/api"
"github.com/isrc-cas/gt/client/webrtc"
"github.com/isrc-cas/gt/logger"
)

Expand All @@ -41,6 +41,7 @@ type Client struct {
apiServer *api.Server
services atomic.Pointer[services]
tcpForwardListener net.Listener
webrtcThreadPool *webrtc.ThreadPool
waitTunnelsShutdown sync.WaitGroup
configChecksum atomic.Pointer[[32]byte]
reloadWaitGroup sync.WaitGroup
Expand Down
5 changes: 4 additions & 1 deletion client/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ func (pt *peerTask) init(c *conn) (err error) {
OnICECandidate: pt.OnICECandidate,
OnICECandidateError: pt.OnICECandidateError,
}
err = webrtc.NewPeerConnection(&peerConnectionConfig, &pt.conn)
signalingThread := pt.tunnel.client.webrtcThreadPool.GetThread()
networkThread := pt.tunnel.client.webrtcThreadPool.GetSocketThread()
workerThread := pt.tunnel.client.webrtcThreadPool.GetThread()
err = webrtc.NewPeerConnection(&peerConnectionConfig, &pt.conn, signalingThread, networkThread, workerThread)
return
}

Expand Down
3 changes: 2 additions & 1 deletion client/release.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

//go:build release
// +build release

package client

Expand All @@ -23,6 +22,7 @@ import (
"sync/atomic"

"github.com/isrc-cas/gt/client/api"
"github.com/isrc-cas/gt/client/webrtc"
"github.com/isrc-cas/gt/logger"
)

Expand All @@ -41,6 +41,7 @@ type Client struct {
apiServer *api.Server
services atomic.Pointer[services]
tcpForwardListener net.Listener
webrtcThreadPool *webrtc.ThreadPool
waitTunnelsShutdown sync.WaitGroup
configChecksum atomic.Pointer[[32]byte]
reloadWaitGroup sync.WaitGroup
Expand Down
5 changes: 4 additions & 1 deletion client/tcpforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ func (c *Client) createPeerConnection(dialer dialer) (peerConnection *webrtc.Pee
OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) {
},
}
err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection)
signalingThread := c.webrtcThreadPool.GetThread()
networkThread := c.webrtcThreadPool.GetSocketThread()
workerThread := c.webrtcThreadPool.GetThread()
err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection, signalingThread, networkThread, workerThread)
if err != nil {
return
}
Expand Down
15 changes: 6 additions & 9 deletions client/webrtc/datachannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
#include "datachannel.h"
#include "datachannel.hpp"

using namespace std;
using namespace rtc;
using namespace webrtc;

::DataChannelObserver::DataChannelObserver(DataChannelInterface *dataChannel, void *userData)
::DataChannelObserver::DataChannelObserver(webrtc::DataChannelInterface *dataChannel,
void *userData)
: dataChannel(dataChannel), userData(userData) {}

::DataChannelObserver::~DataChannelObserver() {
Expand All @@ -34,7 +31,7 @@ void ::DataChannelObserver::OnStateChange() {
onDataChannelStateChange((int)dataChannel->state(), dataChannel->id(), (void *)this, userData);
}

void ::DataChannelObserver::OnMessage(const DataBuffer &buffer) {
void ::DataChannelObserver::OnMessage(const webrtc::DataBuffer &buffer) {
onDataChannelMessage((void *)buffer.data.data(), buffer.size(), (void *)this, userData);
}

Expand All @@ -50,12 +47,12 @@ void DeleteDataChannel(void *dataChannel) {
bool DataChannelSend(void *buf, int bufLen, void *dataChannel) {
auto dataChannelObserverInternal = (::DataChannelObserver *)dataChannel;
return dataChannelObserverInternal->dataChannel->Send(
DataBuffer(CopyOnWriteBuffer((char *)buf, (size_t)bufLen), true));
webrtc::DataBuffer(rtc::CopyOnWriteBuffer((char *)buf, (size_t)bufLen), true));
}

void SetDataChannelCallback(void *dataChannelWithoutCallback, void **dataChannelOutside,
void *userData) {
auto dataChannel = (DataChannelInterface *)dataChannelWithoutCallback;
auto dataChannel = (webrtc::DataChannelInterface *)dataChannelWithoutCallback;
auto dataChannelObserver = new ::DataChannelObserver(dataChannel, userData);
*dataChannelOutside = (void *)dataChannelObserver;
dataChannel->RegisterObserver(dataChannelObserver);
Expand Down Expand Up @@ -94,7 +91,7 @@ char *GetDataChannelError(void *DataChannel) {
auto dataChannelObserverInternal = (::DataChannelObserver *)DataChannel;
char *err = nullptr;
if (!dataChannelObserverInternal->dataChannel->error().ok()) {
stringstream ss;
std::stringstream ss;
ss << "type:'" << ToString(dataChannelObserverInternal->dataChannel->error().type())
<< "' message:'" << dataChannelObserverInternal->dataChannel->error().message()
<< "' error_detail:'"
Expand Down
12 changes: 8 additions & 4 deletions client/webrtc/datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (d *DataChannel) Send(message []byte) bool {

func (d *DataChannel) Close() {
if d.closed.CompareAndSwap(false, true) {
d.bufferedAmountChangeCond.Signal()
d.bufferedAmountChangeCond.Broadcast()
C.DeleteDataChannel(d.dataChannel)
pointer.Unref(d.pointerID)
}
Expand Down Expand Up @@ -186,9 +186,9 @@ func (d *DataChannel) State() DataState {

func (d *DataChannel) Error() string {
errorC := C.GetDataChannelError(d.dataChannel)
err := C.GoString(errorC)
error := C.GoString(errorC)
C.free(unsafe.Pointer(errorC))
return err
return error
}

func (d *DataChannel) MessageSent() uint32 {
Expand Down Expand Up @@ -231,5 +231,9 @@ func (d *DataChannelWithoutCallback) SetCallback(config *DataChannelConfig, data
channel.bufferedAmountChangeCond = sync.NewCond(&channel.bufferedAmountMtx)
*dataChannelPointer = channel
(*dataChannelPointer).pointerID = pointer.Save(*dataChannelPointer)
C.SetDataChannelCallback(d.pointer, &(*dataChannelPointer).dataChannel, (*dataChannelPointer).pointerID)
C.SetDataChannelCallback(
d.pointer,
&(*dataChannelPointer).dataChannel,
(*dataChannelPointer).pointerID,
)
}
Loading

0 comments on commit ee826f9

Please sign in to comment.