From ee826f9d9fc136262befbb37ff3b60608dc5957f Mon Sep 17 00:00:00 2001 From: zhiyi Date: Thu, 16 Nov 2023 14:39:24 +0800 Subject: [PATCH] * limit max number of webrtc p2p connections * add webrtc thread pool support --- client/api/server.go | 5 + client/client.go | 10 +- client/config.go | 4 +- client/conn.go | 37 +++- client/debug.go | 3 +- client/peer.go | 5 +- client/release.go | 3 +- client/tcpforward.go | 5 +- client/webrtc/datachannel.cpp | 15 +- client/webrtc/datachannel.go | 12 +- client/webrtc/integrate_test.go | 208 --------------------- client/webrtc/logging.cpp | 31 ++- client/webrtc/logging.go | 5 +- client/webrtc/peerconnection.cpp | 138 +++++++------- client/webrtc/peerconnection.go | 45 ++++- client/webrtc/peerconnection.h | 3 +- client/webrtc/speedtest/client/main.go | 136 -------------- client/webrtc/speedtest/echoServer/main.go | 115 ------------ client/webrtc/threadpool.cpp | 105 +++++++++++ client/webrtc/threadpool.go | 41 ++++ client/webrtc/threadpool.h | 33 ++++ predef/debug.go | 3 +- predef/release.go | 6 - test/p2p_test.go | 4 +- 24 files changed, 386 insertions(+), 586 deletions(-) delete mode 100644 client/webrtc/integrate_test.go delete mode 100644 client/webrtc/speedtest/client/main.go delete mode 100644 client/webrtc/speedtest/echoServer/main.go create mode 100644 client/webrtc/threadpool.cpp create mode 100644 client/webrtc/threadpool.go create mode 100644 client/webrtc/threadpool.h diff --git a/client/api/server.go b/client/api/server.go index cdbb5019..4125b5dc 100644 --- a/client/api/server.go +++ b/client/api/server.go @@ -62,6 +62,11 @@ func (s *Server) Start() { } func (s *Server) exchangeWebRTCInfo(writer http.ResponseWriter, request *http.Request) { + defer func() { + if request.Body != nil { + request.Body.Close() + } + }() value := request.Context().Value(rawConn) if value == nil { return diff --git a/client/client.go b/client/client.go index cbf2b69b..201ea4c5 100644 --- a/client/client.go +++ b/client/client.go @@ -95,6 +95,7 @@ func New(args []string, out io.Writer) (c *Client, err error) { c.tunnelsCond = sync.NewCond(c.tunnelsRWMtx.RLocker()) c.apiServer = api.NewServer(l.With().Str("scope", "api").Logger()) c.apiServer.ReadTimeout = 30 * time.Second + c.webrtcThreadPool = webrtc.NewThreadPool(3) return } func getDefaultConfig(args []string) (conf Config) { @@ -475,6 +476,11 @@ func (c *Client) Start() (err error) { } else if c.Config().RemoteIdleConnections > c.Config().RemoteConnections { c.Config().RemoteIdleConnections = c.Config().RemoteConnections } + if c.Config().WebRTCRemoteConnections < 1 { + c.Config().WebRTCRemoteConnections = 1 + } else if c.Config().WebRTCRemoteConnections > 50 { + c.Config().WebRTCRemoteConnections = 50 + } c.idleManager = newIdleManager(c.Config().RemoteIdleConnections) conf4Log := *c.Config() @@ -491,8 +497,6 @@ func (c *Client) Start() (err error) { // tcpforward if c.Config().TCPForwardConnections < 1 { c.Config().TCPForwardConnections = 1 - } else if c.Config().TCPForwardConnections > 10 { - c.Config().TCPForwardConnections = 10 } if c.Config().TCPForwardHostPrefix != "" { c.tcpForwardListener, err = net.Listen("tcp", c.Config().TCPForwardAddr) @@ -554,6 +558,7 @@ func (c *Client) Close() { if c.tcpForwardListener != nil { _ = c.tcpForwardListener.Close() } + //c.webrtcThreadPool.Close() } // Shutdown stops the client gracefully. @@ -587,6 +592,7 @@ func (c *Client) ShutdownWithoutClosingLogger() { if c.tcpForwardListener != nil { _ = c.tcpForwardListener.Close() } + //c.webrtcThreadPool.Close() } func (c *Client) initConn(d dialer, connID uint) (result *conn, err error) { diff --git a/client/config.go b/client/config.go index a59dc4ea..bb613db1 100644 --- a/client/config.go +++ b/client/config.go @@ -17,7 +17,6 @@ package client import ( "encoding/json" "fmt" - "gopkg.in/yaml.v3" "net/url" "strconv" "strings" @@ -26,6 +25,7 @@ import ( "github.com/isrc-cas/gt/config" "github.com/isrc-cas/gt/predef" "github.com/rs/zerolog" + "gopkg.in/yaml.v3" ) // Config is a client config. @@ -66,6 +66,7 @@ type Options struct { SentryDebug bool `yaml:"sentryDebug,omitempty" json:",omitempty" usage:"Sentry debug mode, the debug information is printed to help you understand what sentry is doing"` WebRTCConnectionIdleTimeout config.Duration `yaml:"webrtcConnectionIdleTimeout,omitempty" usage:"The timeout of WebRTC connection. Supports values like '30s', '5m'"` + WebRTCRemoteConnections uint `yaml:"webrtcConnections" usage:"The max number of webrtc connections. Valid value is 1 to 50"` WebRTCLogLevel string `yaml:"webrtcLogLevel,omitempty" json:",omitempty" usage:"WebRTC log level: verbose, info, warning, error"` WebRTCMinPort uint16 `yaml:"webrtcMinPort,omitempty" json:",omitempty" usage:"The min port of WebRTC peer connection"` WebRTCMaxPort uint16 `yaml:"webrtcMaxPort,omitempty" json:",omitempty" usage:"The max port of WebRTC peer connection"` @@ -108,6 +109,7 @@ func defaultConfig() Config { SentryRelease: predef.Version, WebRTCConnectionIdleTimeout: config.Duration{Duration: 5 * time.Minute}, + WebRTCRemoteConnections: 30, WebRTCLogLevel: "warning", TCPForwardConnections: 3, diff --git a/client/conn.go b/client/conn.go index b86c782c..670f0146 100644 --- a/client/conn.go +++ b/client/conn.go @@ -494,7 +494,31 @@ func (c *conn) processData(taskID uint32, r *bufio.LimitedReader) (readErr, writ } func (c *conn) processP2P(id uint32, r *bufio.LimitedReader) { - t := &peerTask{} + t, ok := c.newPeerTask(id) + if !ok { + return + } + + c.client.apiServer.Listener.AcceptCh() <- t.apiConn + t.Logger.Info().Msg("peer task started") + _, err := r.WriteTo(t.apiConn.PipeWriter) + if err != nil { + t.Logger.Error().Err(err).Msg("processP2P WriteTo failed") + } +} + +func (c *conn) newPeerTask(id uint32) (t *peerTask, ok bool) { + c.client.peersRWMtx.Lock() + defer c.client.peersRWMtx.Unlock() + l := uint(len(c.client.peers)) + if l >= c.client.Config().WebRTCRemoteConnections { + respAndClose(id, c, [][]byte{ + []byte("HTTP/1.1 403 Forbidden\r\nConnection: Closed\r\n\r\n"), + }) + return + } + + t = &peerTask{} t.id = id t.tunnel = c t.apiConn = api.NewConn(id, "", c) @@ -513,19 +537,12 @@ func (c *conn) processP2P(id uint32, r *bufio.LimitedReader) { t.CloseWithLock() }) - c.client.peersRWMtx.Lock() ot, ok := c.client.peers[id] if ok && ot != nil { ot.CloseWithLock() ot.Logger.Info().Msg("got closed because task with same id is received") } c.client.peers[id] = t - c.client.peersRWMtx.Unlock() - - c.client.apiServer.Listener.AcceptCh() <- t.apiConn - t.Logger.Info().Msg("peer task started") - _, err := r.WriteTo(t.apiConn.PipeWriter) - if err != nil { - t.Logger.Error().Err(err).Msg("processP2P WriteTo failed") - } + ok = true + return } diff --git a/client/debug.go b/client/debug.go index f1c2075b..9b8ca9d9 100644 --- a/client/debug.go +++ b/client/debug.go @@ -13,7 +13,6 @@ // limitations under the License. //go:build !release -// +build !release package client @@ -23,6 +22,7 @@ import ( "sync/atomic" "github.com/isrc-cas/gt/client/api" + "github.com/isrc-cas/gt/client/webrtc" "github.com/isrc-cas/gt/logger" ) @@ -41,6 +41,7 @@ type Client struct { apiServer *api.Server services atomic.Pointer[services] tcpForwardListener net.Listener + webrtcThreadPool *webrtc.ThreadPool waitTunnelsShutdown sync.WaitGroup configChecksum atomic.Pointer[[32]byte] reloadWaitGroup sync.WaitGroup diff --git a/client/peer.go b/client/peer.go index ef5b775d..3ef2fa30 100644 --- a/client/peer.go +++ b/client/peer.go @@ -134,7 +134,10 @@ func (pt *peerTask) init(c *conn) (err error) { OnICECandidate: pt.OnICECandidate, OnICECandidateError: pt.OnICECandidateError, } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &pt.conn) + signalingThread := pt.tunnel.client.webrtcThreadPool.GetThread() + networkThread := pt.tunnel.client.webrtcThreadPool.GetSocketThread() + workerThread := pt.tunnel.client.webrtcThreadPool.GetThread() + err = webrtc.NewPeerConnection(&peerConnectionConfig, &pt.conn, signalingThread, networkThread, workerThread) return } diff --git a/client/release.go b/client/release.go index ecc0537a..62b55ea1 100644 --- a/client/release.go +++ b/client/release.go @@ -13,7 +13,6 @@ // limitations under the License. //go:build release -// +build release package client @@ -23,6 +22,7 @@ import ( "sync/atomic" "github.com/isrc-cas/gt/client/api" + "github.com/isrc-cas/gt/client/webrtc" "github.com/isrc-cas/gt/logger" ) @@ -41,6 +41,7 @@ type Client struct { apiServer *api.Server services atomic.Pointer[services] tcpForwardListener net.Listener + webrtcThreadPool *webrtc.ThreadPool waitTunnelsShutdown sync.WaitGroup configChecksum atomic.Pointer[[32]byte] reloadWaitGroup sync.WaitGroup diff --git a/client/tcpforward.go b/client/tcpforward.go index d16a3852..4e4fb4fe 100644 --- a/client/tcpforward.go +++ b/client/tcpforward.go @@ -204,7 +204,10 @@ func (c *Client) createPeerConnection(dialer dialer) (peerConnection *webrtc.Pee OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) { }, } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection) + signalingThread := c.webrtcThreadPool.GetThread() + networkThread := c.webrtcThreadPool.GetSocketThread() + workerThread := c.webrtcThreadPool.GetThread() + err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection, signalingThread, networkThread, workerThread) if err != nil { return } diff --git a/client/webrtc/datachannel.cpp b/client/webrtc/datachannel.cpp index 75e40d70..8dd67d91 100644 --- a/client/webrtc/datachannel.cpp +++ b/client/webrtc/datachannel.cpp @@ -17,11 +17,8 @@ #include "datachannel.h" #include "datachannel.hpp" -using namespace std; -using namespace rtc; -using namespace webrtc; - -::DataChannelObserver::DataChannelObserver(DataChannelInterface *dataChannel, void *userData) +::DataChannelObserver::DataChannelObserver(webrtc::DataChannelInterface *dataChannel, + void *userData) : dataChannel(dataChannel), userData(userData) {} ::DataChannelObserver::~DataChannelObserver() { @@ -34,7 +31,7 @@ void ::DataChannelObserver::OnStateChange() { onDataChannelStateChange((int)dataChannel->state(), dataChannel->id(), (void *)this, userData); } -void ::DataChannelObserver::OnMessage(const DataBuffer &buffer) { +void ::DataChannelObserver::OnMessage(const webrtc::DataBuffer &buffer) { onDataChannelMessage((void *)buffer.data.data(), buffer.size(), (void *)this, userData); } @@ -50,12 +47,12 @@ void DeleteDataChannel(void *dataChannel) { bool DataChannelSend(void *buf, int bufLen, void *dataChannel) { auto dataChannelObserverInternal = (::DataChannelObserver *)dataChannel; return dataChannelObserverInternal->dataChannel->Send( - DataBuffer(CopyOnWriteBuffer((char *)buf, (size_t)bufLen), true)); + webrtc::DataBuffer(rtc::CopyOnWriteBuffer((char *)buf, (size_t)bufLen), true)); } void SetDataChannelCallback(void *dataChannelWithoutCallback, void **dataChannelOutside, void *userData) { - auto dataChannel = (DataChannelInterface *)dataChannelWithoutCallback; + auto dataChannel = (webrtc::DataChannelInterface *)dataChannelWithoutCallback; auto dataChannelObserver = new ::DataChannelObserver(dataChannel, userData); *dataChannelOutside = (void *)dataChannelObserver; dataChannel->RegisterObserver(dataChannelObserver); @@ -94,7 +91,7 @@ char *GetDataChannelError(void *DataChannel) { auto dataChannelObserverInternal = (::DataChannelObserver *)DataChannel; char *err = nullptr; if (!dataChannelObserverInternal->dataChannel->error().ok()) { - stringstream ss; + std::stringstream ss; ss << "type:'" << ToString(dataChannelObserverInternal->dataChannel->error().type()) << "' message:'" << dataChannelObserverInternal->dataChannel->error().message() << "' error_detail:'" diff --git a/client/webrtc/datachannel.go b/client/webrtc/datachannel.go index 89995715..ac661c49 100644 --- a/client/webrtc/datachannel.go +++ b/client/webrtc/datachannel.go @@ -155,7 +155,7 @@ func (d *DataChannel) Send(message []byte) bool { func (d *DataChannel) Close() { if d.closed.CompareAndSwap(false, true) { - d.bufferedAmountChangeCond.Signal() + d.bufferedAmountChangeCond.Broadcast() C.DeleteDataChannel(d.dataChannel) pointer.Unref(d.pointerID) } @@ -186,9 +186,9 @@ func (d *DataChannel) State() DataState { func (d *DataChannel) Error() string { errorC := C.GetDataChannelError(d.dataChannel) - err := C.GoString(errorC) + error := C.GoString(errorC) C.free(unsafe.Pointer(errorC)) - return err + return error } func (d *DataChannel) MessageSent() uint32 { @@ -231,5 +231,9 @@ func (d *DataChannelWithoutCallback) SetCallback(config *DataChannelConfig, data channel.bufferedAmountChangeCond = sync.NewCond(&channel.bufferedAmountMtx) *dataChannelPointer = channel (*dataChannelPointer).pointerID = pointer.Save(*dataChannelPointer) - C.SetDataChannelCallback(d.pointer, &(*dataChannelPointer).dataChannel, (*dataChannelPointer).pointerID) + C.SetDataChannelCallback( + d.pointer, + &(*dataChannelPointer).dataChannel, + (*dataChannelPointer).pointerID, + ) } diff --git a/client/webrtc/integrate_test.go b/client/webrtc/integrate_test.go deleted file mode 100644 index d87ca787..00000000 --- a/client/webrtc/integrate_test.go +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package webrtc_test - -import ( - "fmt" - "testing" - - "github.com/isrc-cas/gt/client/webrtc" -) - -func TestWebRTC(t *testing.T) { - webrtc.SetLog(webrtc.LoggingSeverityWarning, func(severity webrtc.LoggingSeverity, message, tag string) { - fmt.Println("severity", severity.String(), "message", message, "tag", tag) - }) - offerChan := make(chan *webrtc.SessionDescription) - answerChan := make(chan *webrtc.SessionDescription) - exitChan := make(chan struct{}) - go client(offerChan, answerChan) - go server(offerChan, answerChan, exitChan) - <-exitChan -} - -func client(offerChan, answerChan chan *webrtc.SessionDescription) { - waitNegotiationNeeded := make(chan struct{}) - var peerConnection *webrtc.PeerConnection - var err error - peerConnectionConfig := webrtc.PeerConnectionConfig{ - ICEServers: []string{"stun:stun.l.google.com:19302"}, - OnSignalingChange: func(state webrtc.SignalingState) { - fmt.Println(state.String()) - }, - OnDataChannel: func(dataChannel *webrtc.DataChannelWithoutCallback) { - }, - OnRenegotiationNeeded: func() { - }, - OnNegotiationNeeded: func() { - close(waitNegotiationNeeded) - }, - OnICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnStandardizedICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnConnectionChange: func(state webrtc.PeerConnectionState) { - fmt.Println("peer connection", state.String()) - }, - OnICEGatheringChange: func(state webrtc.ICEGatheringState) { - fmt.Println("ice gathering", state.String()) - }, - OnICECandidate: func(iceCandidate *webrtc.ICECandidate) { - fmt.Printf("get ice candidate:'%#v'\n", iceCandidate) - }, - OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) { - }, - } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection) - if err != nil { - panic(err) - } - - { - var dataChannel *webrtc.DataChannel - dataChannelConfig := webrtc.DataChannelConfig{ - OnStateChange: func(state webrtc.DataState) { - fmt.Println("data channel", state.String()) - if state == webrtc.DataStateOpen { - if !dataChannel.Send([]byte("test data channel send")) { - panic("data channel send failed") - } - } - }, - OnMessage: func(message []byte) { - }, - } - err = peerConnection.CreateDataChannel("test 1", false, &dataChannelConfig, &dataChannel) - if err != nil { - panic(err) - } - } - { - var dataChannel *webrtc.DataChannel - dataChannelConfig := webrtc.DataChannelConfig{ - OnStateChange: func(state webrtc.DataState) { - fmt.Println("data channel", state.String()) - if state == webrtc.DataStateOpen { - if !dataChannel.Send([]byte("test data channel send")) { - panic("data channel send failed") - } - } - }, - OnMessage: func(message []byte) { - }, - } - err = peerConnection.CreateDataChannel("test 2", false, &dataChannelConfig, &dataChannel) - if err != nil { - panic(err) - } - } - - <-waitNegotiationNeeded - offer, err := peerConnection.CreateOffer() - if err != nil { - panic(err) - } - err = peerConnection.SetLocalDescription(offer) - if err != nil { - panic(err) - } - fmt.Printf("send offer:'%#v'\n", offer) - offerChan <- offer - - answer := <-answerChan - err = peerConnection.SetRemoteDescription(answer) - if err != nil { - panic(err) - } - fmt.Printf("receive answer:'%#v'\n", offer) -} - -func server(offerChan, answerChan chan *webrtc.SessionDescription, exitChan chan struct{}) { - waitICEGatheringComplete := make(chan struct{}) - var peerConnection *webrtc.PeerConnection - var err error - peerConnectionConfig := webrtc.PeerConnectionConfig{ - ICEServers: []string{"stun:stun.l.google.com:19302"}, - OnSignalingChange: func(state webrtc.SignalingState) { - fmt.Println("signaling", state.String()) - }, - OnDataChannel: func(dataChannelWithoutCallback *webrtc.DataChannelWithoutCallback) { - var dataChannel *webrtc.DataChannel - dataChannelConfig := webrtc.DataChannelConfig{ - OnStateChange: func(state webrtc.DataState) { - fmt.Println("data channel", state.String()) - if state == webrtc.DataStateOpen { - if !dataChannel.Send([]byte("test data channel send")) { - panic("data channel send failed") - } - } - }, - OnMessage: func(message []byte) { - fmt.Println("server recieve data channel message:", string(message)) - if dataChannel.Label == "test 2" { - close(exitChan) - } - }, - } - dataChannelWithoutCallback.SetCallback(&dataChannelConfig, &dataChannel) - }, - OnRenegotiationNeeded: func() { - }, - OnNegotiationNeeded: func() { - }, - OnICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnStandardizedICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnConnectionChange: func(state webrtc.PeerConnectionState) { - fmt.Println("peer connection", state.String()) - }, - OnICEGatheringChange: func(state webrtc.ICEGatheringState) { - fmt.Println("ice gathering", state.String()) - if state == webrtc.ICEGatheringStateComplete { - close(waitICEGatheringComplete) - } - }, - OnICECandidate: func(iceCandidate *webrtc.ICECandidate) { - fmt.Printf("get ice candidate:'%#v'\n", iceCandidate) - }, - OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) { - }, - } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection) - if err != nil { - panic(err) - } - - offer := <-offerChan - fmt.Printf("receive offer:'%#v'\n", offer) - err = peerConnection.SetRemoteDescription(offer) - if err != nil { - panic(err) - } - - answer, err := peerConnection.CreateAnswer() - if err != nil { - panic(err) - } - err = peerConnection.SetLocalDescription(answer) - if err != nil { - panic(err) - } - <-waitICEGatheringComplete - answer = peerConnection.GetLocalDescription() - answerChan <- answer - fmt.Printf("send answer:'%#v'\n", answer) -} diff --git a/client/webrtc/logging.cpp b/client/webrtc/logging.cpp index c53bee45..38f14a7f 100644 --- a/client/webrtc/logging.cpp +++ b/client/webrtc/logging.cpp @@ -18,13 +18,10 @@ #include "logging.h" -using namespace std; -using namespace rtc; - class LogSink : public rtc::LogSink { protected: - void OnLogMessage(const string &message, LoggingSeverity severity, const char *tag) { - auto messageStr = (string)message; + void OnLogMessage(const std::string &message, rtc::LoggingSeverity severity, const char *tag) { + auto messageStr = (std::string)message; if (messageStr.back() = '\n') { messageStr.pop_back(); } @@ -34,8 +31,8 @@ class LogSink : public rtc::LogSink { onLogMessage(severity, (char *)messageStr.c_str(), (char *)tag); } - void OnLogMessage(const string &message, LoggingSeverity severity) { - auto messageStr = (string)message; + void OnLogMessage(const std::string &message, rtc::LoggingSeverity severity) { + auto messageStr = (std::string)message; if (messageStr.back() = '\n') { messageStr.pop_back(); } @@ -45,8 +42,8 @@ class LogSink : public rtc::LogSink { onLogMessage(severity, (char *)messageStr.c_str(), nullptr); } - void OnLogMessage(const string &message) { - auto messageStr = (string)message; + void OnLogMessage(const std::string &message) { + auto messageStr = (std::string)message; if (messageStr.back() = '\n') { messageStr.pop_back(); } @@ -56,8 +53,8 @@ class LogSink : public rtc::LogSink { onLogMessage(rtc::LS_INFO, (char *)messageStr.c_str(), nullptr); } - void OnLogMessage(absl::string_view message, LoggingSeverity severity, const char *tag) { - auto messageStr = (string)message; + void OnLogMessage(absl::string_view message, rtc::LoggingSeverity severity, const char *tag) { + auto messageStr = (std::string)message; if (messageStr.back() = '\n') { messageStr.pop_back(); } @@ -67,8 +64,8 @@ class LogSink : public rtc::LogSink { onLogMessage(severity, (char *)messageStr.c_str(), (char *)tag); } - void OnLogMessage(absl::string_view message, LoggingSeverity severity) { - auto messageStr = (string)message; + void OnLogMessage(absl::string_view message, rtc::LoggingSeverity severity) { + auto messageStr = (std::string)message; if (messageStr.back() = '\n') { messageStr.pop_back(); } @@ -79,7 +76,7 @@ class LogSink : public rtc::LogSink { } void OnLogMessage(absl::string_view message) { - auto messageStr = (string)message; + auto messageStr = (std::string)message; if (messageStr.back() = '\n') { messageStr.pop_back(); } @@ -91,14 +88,14 @@ class LogSink : public rtc::LogSink { }; void SetLog(int severity) { - static mutex m; + static std::mutex m; m.lock(); static ::LogSink *stream = nullptr; if (stream != nullptr) { - LogMessage::RemoveLogToStream(stream); + rtc::LogMessage::RemoveLogToStream(stream); delete stream; } stream = new ::LogSink(); - LogMessage::AddLogToStream(stream, LoggingSeverity(severity)); + rtc::LogMessage::AddLogToStream(stream, rtc::LoggingSeverity(severity)); m.unlock(); } diff --git a/client/webrtc/logging.go b/client/webrtc/logging.go index afd44774..49460e7a 100644 --- a/client/webrtc/logging.go +++ b/client/webrtc/logging.go @@ -66,7 +66,10 @@ func onLogMessage(severity C.int, messageC *C.char, tagC *C.char) { } // SetLog set logging severity and onLogMessage callback -func SetLog(severity LoggingSeverity, f func(severity LoggingSeverity, message string, tag string)) { +func SetLog( + severity LoggingSeverity, + f func(severity LoggingSeverity, message string, tag string), +) { onLogMessageGlobalRWMutex.Lock() onLogMessageGlobal = f onLogMessageGlobalRWMutex.Unlock() diff --git a/client/webrtc/peerconnection.cpp b/client/webrtc/peerconnection.cpp index de6ad801..944892bd 100644 --- a/client/webrtc/peerconnection.cpp +++ b/client/webrtc/peerconnection.cpp @@ -21,18 +21,14 @@ #include "datachannel.hpp" #include "peerconnection.h" -using namespace std; -using namespace webrtc; -using namespace rtc; - class SetLocalDescriptionObserver : public webrtc::SetLocalDescriptionObserverInterface { public: SetLocalDescriptionObserver(void *userData) : userData(userData) {} - void OnSetLocalDescriptionComplete(RTCError error) override { + void OnSetLocalDescriptionComplete(webrtc::RTCError error) override { char *err = nullptr; if (!error.ok()) { - stringstream ss; + std::stringstream ss; ss << "type:'" << ToString(error.type()) << "' message:'" << error.message() << "' error_detail:'" << ToString(error.error_detail()) << "'"; err = (char *)ss.str().c_str(); @@ -48,10 +44,10 @@ class SetRemoteDescriptionObserver : public webrtc::SetRemoteDescriptionObserver public: SetRemoteDescriptionObserver(void *userData) : userData(userData) {} - void OnSetRemoteDescriptionComplete(RTCError error) override { + void OnSetRemoteDescriptionComplete(webrtc::RTCError error) override { char *err = nullptr; if (!error.ok()) { - stringstream ss; + std::stringstream ss; ss << "type:'" << ToString(error.type()) << "' message:'" << error.message() << "' error_detail:'" << ToString(error.error_detail()) << "'"; err = (char *)ss.str().c_str(); @@ -63,20 +59,20 @@ class SetRemoteDescriptionObserver : public webrtc::SetRemoteDescriptionObserver void *userData; }; -class CreateOfferObserver : public CreateSessionDescriptionObserver { +class CreateOfferObserver : public webrtc::CreateSessionDescriptionObserver { public: CreateOfferObserver(void *userData) : userData(userData) {} protected: - void OnSuccess(SessionDescriptionInterface *desc) { - string descStr; + void OnSuccess(webrtc::SessionDescriptionInterface *desc) { + std::string descStr; desc->ToString(&descStr); onOffer((char *)descStr.c_str(), nullptr, userData); } - void OnFailure(RTCError error) { + void OnFailure(webrtc::RTCError error) { if (!error.ok()) { - stringstream ss; + std::stringstream ss; ss << "type:'" << ToString(error.type()) << "' message:'" << error.message() << "' error_detail:'" << ToString(error.error_detail()) << "'"; onOffer(nullptr, (char *)ss.str().c_str(), userData); @@ -87,20 +83,20 @@ class CreateOfferObserver : public CreateSessionDescriptionObserver { void *userData; }; -class CreateAnswerObserver : public CreateSessionDescriptionObserver { +class CreateAnswerObserver : public webrtc::CreateSessionDescriptionObserver { public: CreateAnswerObserver(void *userData) : userData(userData) {} protected: - void OnSuccess(SessionDescriptionInterface *desc) { - string descStr; + void OnSuccess(webrtc::SessionDescriptionInterface *desc) { + std::string descStr; desc->ToString(&descStr); onAnswer((char *)descStr.c_str(), nullptr, userData); } - void OnFailure(RTCError error) { + void OnFailure(webrtc::RTCError error) { if (!error.ok()) { - stringstream ss; + std::stringstream ss; ss << "type:'" << ToString(error.type()) << "' message:'" << error.message() << "' error_detail:'" << ToString(error.error_detail()) << "'"; onAnswer(nullptr, (char *)ss.str().c_str(), userData); @@ -114,8 +110,8 @@ class CreateAnswerObserver : public CreateSessionDescriptionObserver { class PeerConnectionObserver : public webrtc::PeerConnectionObserver { public: PeerConnectionObserver(void *userData) : userData(userData) { - createOfferObserver = make_ref_counted(userData); - createAnswerObserver = make_ref_counted(userData); + createOfferObserver = rtc::make_ref_counted(userData); + createAnswerObserver = rtc::make_ref_counted(userData); } void Delete() { @@ -124,20 +120,30 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { // 的同时也会释放 this } - char *Start(char **iceServers, int iceServersLen, uint16_t *minPort, uint16_t *maxPort) { - signalingThread = Thread::Create(); - auto ok = signalingThread->Start(); - if (!ok) { - return (char *)"failed to start signal thread"; + char *Start(char **iceServers, int iceServersLen, uint16_t *minPort, uint16_t *maxPort, + void *signalingThreadOutside, void *networkThreadOutside, + void *workerThreadOutside) { + if (signalingThreadOutside == nullptr) { + ownedSignalingThread = rtc::Thread::Create(); + auto ok = ownedSignalingThread->Start(); + if (!ok) { + return (char *)"signalingThread start failed"; + } + signalingThread = ownedSignalingThread.get(); + } else { + signalingThread = (rtc::Thread *)signalingThreadOutside; } - PeerConnectionFactoryDependencies dependencies; - dependencies.signaling_thread = signalingThread.get(); - auto peerConnectionFactory = CreateModularPeerConnectionFactory(move(dependencies)); + webrtc::PeerConnectionFactoryDependencies dependencies; + dependencies.signaling_thread = signalingThread; + dependencies.network_thread = (rtc::Thread *)networkThreadOutside; + dependencies.worker_thread = (rtc::Thread *)workerThreadOutside; + auto peerConnectionFactory = + webrtc::CreateModularPeerConnectionFactory(std::move(dependencies)); - PeerConnectionInterface::RTCConfiguration configuration; + webrtc::PeerConnectionInterface::RTCConfiguration configuration; if (iceServersLen > 0) { - PeerConnectionInterface::IceServer iceServer; + webrtc::PeerConnectionInterface::IceServer iceServer; for (int i = 0; i < iceServersLen; i++) { iceServer.urls.push_back(iceServers[i]); } @@ -150,12 +156,12 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { configuration.set_max_port(*maxPort); } - PeerConnectionDependencies connectionDependencies(this); + webrtc::PeerConnectionDependencies connectionDependencies(this); auto peerConnectionOrError = peerConnectionFactory->CreatePeerConnectionOrError( - configuration, move(connectionDependencies)); + configuration, std::move(connectionDependencies)); if (!peerConnectionOrError.ok()) { - stringstream ss; + std::stringstream ss; ss << "type:'" << ToString(peerConnectionOrError.error().type()) << "' message:'" << peerConnectionOrError.error().message() << "' error_detail:'" << ToString(peerConnectionOrError.error().error_detail()) << "'"; @@ -173,14 +179,14 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { void *dataChannelUserData) { char *err = nullptr; signalingThread->BlockingCall([&] { - DataChannelInit config; + webrtc::DataChannelInit config; config.negotiated = negotiated; auto dataChannelOrError = peerConnection->CreateDataChannelOrError(label, &config); if (!dataChannelOrError.ok()) { - stringstream ss; - ss << "type:'" << ToString(dataChannelOrError.error().type()) << "' message:'" - << dataChannelOrError.error().message() << "' error_detail:'" - << ToString(dataChannelOrError.error().error_detail()) << "'"; + std::stringstream ss; + ss << "type:'" << webrtc::ToString(dataChannelOrError.error().type()) + << "' message:'" << dataChannelOrError.error().message() << "' error_detail:'" + << webrtc::ToString(dataChannelOrError.error().error_detail()) << "'"; auto str = ss.str(); auto buf = calloc(str.size() + 1, 1); memcpy(buf, str.data(), str.size()); @@ -199,24 +205,24 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { void CreateOffer() { signalingThread->BlockingCall([&] { - PeerConnectionInterface::RTCOfferAnswerOptions options; + webrtc::PeerConnectionInterface::RTCOfferAnswerOptions options; peerConnection->CreateOffer(createOfferObserver.get(), options); }); } void CreateAnswer() { signalingThread->BlockingCall([&] { - PeerConnectionInterface::RTCOfferAnswerOptions options; + webrtc::PeerConnectionInterface::RTCOfferAnswerOptions options; peerConnection->CreateAnswer(createAnswerObserver.get(), options); }); } void SetDescription(int isLocal, int sdpType, char *sdp) { signalingThread->BlockingCall([&] { - SdpParseError error; - auto desc = CreateSessionDescription((SdpType)sdpType, sdp, &error); + webrtc::SdpParseError error; + auto desc = webrtc::CreateSessionDescription((webrtc::SdpType)sdpType, sdp, &error); if (desc == nullptr) { - stringstream ss; + std::stringstream ss; ss << "line:'" << error.line << "' description:'" << error.description << "'"; if (isLocal) { ::onSetLocalDescription((char *)ss.str().c_str(), userData); @@ -227,10 +233,12 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { } if (isLocal) { peerConnection->SetLocalDescription( - move(desc), make_ref_counted<::SetLocalDescriptionObserver>(userData)); + std::move(desc), + rtc::make_ref_counted<::SetLocalDescriptionObserver>(userData)); } else { peerConnection->SetRemoteDescription( - move(desc), make_ref_counted<::SetRemoteDescriptionObserver>(userData)); + std::move(desc), + rtc::make_ref_counted<::SetRemoteDescriptionObserver>(userData)); } }); } @@ -243,7 +251,7 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { } else { desc = peerConnection->remote_description(); } - string descStr; + std::string descStr; *sdpType = (int)desc->GetType(); desc->ToString(&descStr); *sdp = (char *)calloc(1, descStr.size() + 1); @@ -254,10 +262,10 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { char *AddICECandidate(char *sdpMid, int sdpMLineIndex, char *sdp) { char *err = nullptr; signalingThread->BlockingCall([&] { - SdpParseError sdpParseError; + webrtc::SdpParseError sdpParseError; auto candidate = CreateIceCandidate(sdpMid, sdpMLineIndex, sdp, &sdpParseError); if (candidate == nullptr) { - stringstream ss; + std::stringstream ss; ss << "line:'" << sdpParseError.line << "' description:'" << sdpParseError.description << "'"; auto str = ss.str(); @@ -272,11 +280,11 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { } protected: - void OnSignalingChange(PeerConnectionInterface::SignalingState new_state) { + void OnSignalingChange(webrtc::PeerConnectionInterface::SignalingState new_state) { ::onSignalingChange((int)new_state, userData); } - void OnDataChannel(scoped_refptr data_channel) { + void OnDataChannel(rtc::scoped_refptr data_channel) { auto dataChannelReleased = data_channel.release(); ::onDataChannel((char *)dataChannelReleased->label().c_str(), dataChannelReleased->id(), (void *)dataChannelReleased, userData); @@ -290,7 +298,8 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { } } - void OnStandardizedIceConnectionChange(PeerConnectionInterface::IceConnectionState new_state) { + void OnStandardizedIceConnectionChange( + webrtc::PeerConnectionInterface::IceConnectionState new_state) { ::onStandardizedICEConnectionChange((int)new_state, userData); } @@ -300,39 +309,42 @@ class PeerConnectionObserver : public webrtc::PeerConnectionObserver { (char *)error_text.c_str(), userData); } - void OnIceConnectionChange(PeerConnectionInterface::IceConnectionState new_state) { + void OnIceConnectionChange(webrtc::PeerConnectionInterface::IceConnectionState new_state) { ::onICEConnectionChange(new_state, userData); } - void OnConnectionChange(PeerConnectionInterface::PeerConnectionState new_state) { + void OnConnectionChange(webrtc::PeerConnectionInterface::PeerConnectionState new_state) { ::onConnectionChange(int(new_state), userData); } - void OnIceGatheringChange(PeerConnectionInterface::IceGatheringState new_state) { + void OnIceGatheringChange(webrtc::PeerConnectionInterface::IceGatheringState new_state) { ::onICEGatheringChange((int)new_state, userData); } - void OnIceCandidate(const IceCandidateInterface *candidate) { - string sdp; + void OnIceCandidate(const webrtc::IceCandidateInterface *candidate) { + std::string sdp; candidate->ToString(&sdp); ::onICECandidate((char *)candidate->sdp_mid().c_str(), candidate->sdp_mline_index(), (char *)sdp.c_str(), userData); } private: - scoped_refptr peerConnection; - unique_ptr signalingThread; - scoped_refptr createOfferObserver; - scoped_refptr createAnswerObserver; + rtc::scoped_refptr peerConnection; + rtc::Thread *signalingThread; + std::unique_ptr ownedSignalingThread; + rtc::scoped_refptr createOfferObserver; + rtc::scoped_refptr createAnswerObserver; void *userData; }; char *NewPeerConnection(void **peerConnectionOutside, char **iceServers, int iceServersLen, - uint16_t *minPort, uint16_t *maxPort, void *userData) { - auto peerConnectionObserver = make_ref_counted<::PeerConnectionObserver>(userData); + uint16_t *minPort, uint16_t *maxPort, void *signalingThread, + void *networkThread, void *workerThread, void *userData) { + auto peerConnectionObserver = rtc::make_ref_counted<::PeerConnectionObserver>(userData); *peerConnectionOutside = (void *)peerConnectionObserver.release(); auto err = (*(::PeerConnectionObserver **)peerConnectionOutside) - ->Start(iceServers, iceServersLen, minPort, maxPort); + ->Start(iceServers, iceServersLen, minPort, maxPort, signalingThread, + networkThread, workerThread); return err; } diff --git a/client/webrtc/peerconnection.go b/client/webrtc/peerconnection.go index 36531e07..505b7cca 100644 --- a/client/webrtc/peerconnection.go +++ b/client/webrtc/peerconnection.go @@ -160,7 +160,13 @@ type PeerConnectionConfig struct { } // NewPeerConnection 使用 peerConnectionPointer 是为了防止回调时值还没有被设置,因为是不同的线程 -func NewPeerConnection(config *PeerConnectionConfig, peerConnectionPointer **PeerConnection) (err error) { +func NewPeerConnection( + config *PeerConnectionConfig, + peerConnectionPointer **PeerConnection, + signalingThread unsafe.Pointer, + networkThread unsafe.Pointer, + workerThread unsafe.Pointer, +) (err error) { iceServers := C.malloc(C.size_t(len(config.ICEServers)) * C.size_t(unsafe.Sizeof(uintptr(0)))) iceServersPointer := (*[1<<30 - 1]*C.char)(iceServers) for i, stun := range config.ICEServers { @@ -186,7 +192,17 @@ func NewPeerConnection(config *PeerConnectionConfig, peerConnectionPointer **Pee remoteDescriptionErrChan: make(chan error, 1), } (*peerConnectionPointer).pointerID = pointer.Save(*peerConnectionPointer) - errC := C.NewPeerConnection(&(*peerConnectionPointer).peerConnection, (**C.char)(iceServers), C.int(len(config.ICEServers)), (*C.uint16_t)(config.MinPort), (*C.uint16_t)(config.MaxPort), (*peerConnectionPointer).pointerID) + errC := C.NewPeerConnection( + &(*peerConnectionPointer).peerConnection, + (**C.char)(iceServers), + C.int(len(config.ICEServers)), + (*C.uint16_t)(config.MinPort), + (*C.uint16_t)(config.MaxPort), + signalingThread, + networkThread, + workerThread, + (*peerConnectionPointer).pointerID, + ) if errC != nil { err = errors.New(C.GoString(errC)) C.free(unsafe.Pointer(errC)) @@ -332,14 +348,27 @@ func onICECandidate(sdpMid *C.char, sdpMLineIndex C.int, sdp *C.char, userData u } //export onICECandidateError -func onICECandidateError(address *C.char, port C.int, url *C.char, errorCode C.int, errorText *C.char, userData unsafe.Pointer) { +func onICECandidateError( + address *C.char, + port C.int, + url *C.char, + errorCode C.int, + errorText *C.char, + userData unsafe.Pointer, +) { p, ok := pointer.Restore(userData).(*PeerConnection) if !ok || p == nil { return } if p.config != nil && p.config.OnICECandidateError != nil { - p.config.OnICECandidateError(C.GoString(address), int(port), C.GoString(url), int(errorCode), C.GoString(errorText)) + p.config.OnICECandidateError( + C.GoString(address), + int(port), + C.GoString(url), + int(errorCode), + C.GoString(errorText), + ) } } @@ -493,7 +522,13 @@ func (p *PeerConnection) CreateDataChannel(label string, negotiated bool, config (*dataChannelPointer).pointerID = pointer.Save(*dataChannelPointer) labelC := C.CString(label) - errC := C.CreateDataChannel(&(*dataChannelPointer).dataChannel, labelC, C.bool(negotiated), (*dataChannelPointer).pointerID, p.peerConnection) + errC := C.CreateDataChannel( + &(*dataChannelPointer).dataChannel, + labelC, + C.bool(negotiated), + (*dataChannelPointer).pointerID, + p.peerConnection, + ) C.free(unsafe.Pointer(labelC)) if errC != nil { err = errors.New(C.GoString(errC)) diff --git a/client/webrtc/peerconnection.h b/client/webrtc/peerconnection.h index b30951d4..096eb4dd 100644 --- a/client/webrtc/peerconnection.h +++ b/client/webrtc/peerconnection.h @@ -22,7 +22,8 @@ extern "C" { #include char *NewPeerConnection(void **peerConnectionOutside, char **iceServers, int iceServersLen, - uint16_t *minPort, uint16_t *maxPort, void *userData); + uint16_t *minPort, uint16_t *maxPort, void *signalingThread, + void *networkThread, void *workerThread, void *userData); void DeletePeerConnection(void *peerConnection); void onSignalingChange(int new_state, void *userData); diff --git a/client/webrtc/speedtest/client/main.go b/client/webrtc/speedtest/client/main.go deleted file mode 100644 index 9cc0c0df..00000000 --- a/client/webrtc/speedtest/client/main.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bufio" - "crypto/rand" - "encoding/json" - "fmt" - "os" - "time" - - "github.com/isrc-cas/gt/client/webrtc" -) - -func main() { - webrtc.SetLog(webrtc.LoggingSeverityWarning, func(severity webrtc.LoggingSeverity, message, tag string) { - fmt.Println("severity", severity.String(), "message", message, "tag", tag) - }) - go client(100 * 1024 * 1024) - select {} -} - -func client(dataNumber int) { - waitNegotiationNeeded := make(chan struct{}) - var peerConnection *webrtc.PeerConnection - var err error - peerConnectionConfig := webrtc.PeerConnectionConfig{ - ICEServers: []string{"stun:stun.l.google.com:19302"}, - OnSignalingChange: func(state webrtc.SignalingState) { - fmt.Println(state.String()) - }, - OnDataChannel: func(dataChannel *webrtc.DataChannelWithoutCallback) { - }, - OnRenegotiationNeeded: func() { - }, - OnNegotiationNeeded: func() { - close(waitNegotiationNeeded) - }, - OnICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnStandardizedICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnConnectionChange: func(state webrtc.PeerConnectionState) { - fmt.Println("peer connection", state.String()) - }, - OnICEGatheringChange: func(state webrtc.ICEGatheringState) { - fmt.Println("ice gathering", state.String()) - }, - OnICECandidate: func(iceCandidate *webrtc.ICECandidate) { - fmt.Printf("get ice candidate:'%#v'\n", iceCandidate) - }, - OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) { - }, - } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection) - if err != nil { - panic(err) - } - - var dataChannel *webrtc.DataChannel - dataChannelConfig := webrtc.DataChannelConfig{ - OnStateChange: func(state webrtc.DataState) { - fmt.Println("data channel", state.String()) - if state == webrtc.DataStateOpen { - buf := make([]byte, 1024) - _, err := rand.Read(buf) - if err != nil { - panic(err) - } - - startTime := time.Now() - dataNumberCopy := dataNumber - for dataNumberCopy > 0 { - bufLen := len(buf) - if dataNumberCopy < bufLen { - bufLen = dataNumberCopy - } - if !dataChannel.Send(buf[:bufLen]) { - panic("data channel send failed") - } - dataNumberCopy -= bufLen - } - fmt.Printf("all data send done, took %v to send %v bytes, %.2f MB/s\n", time.Since(startTime), dataNumber, float64(dataNumber)/1024/1024/float64(time.Since(startTime)/time.Second)) - } - }, - OnMessage: func(message []byte) { - }, - } - err = peerConnection.CreateDataChannel("speed test", false, &dataChannelConfig, &dataChannel) - if err != nil { - panic(err) - } - - <-waitNegotiationNeeded - offer, err := peerConnection.CreateOffer() - if err != nil { - panic(err) - } - err = peerConnection.SetLocalDescription(offer) - if err != nil { - panic(err) - } - offer = peerConnection.GetLocalDescription() - offerJSON, err := json.Marshal(offer) - if err != nil { - panic(err) - } - fmt.Printf("offer: '%v'\n", string(offerJSON)) - - fmt.Print("please enter answer: ") - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() - answerJSON := scanner.Text() - var answer webrtc.SessionDescription - err = json.Unmarshal([]byte(answerJSON), &answer) - if err != nil { - panic(err) - } - err = peerConnection.SetRemoteDescription(&answer) - if err != nil { - panic(err) - } -} diff --git a/client/webrtc/speedtest/echoServer/main.go b/client/webrtc/speedtest/echoServer/main.go deleted file mode 100644 index a4d2954e..00000000 --- a/client/webrtc/speedtest/echoServer/main.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "bufio" - "encoding/json" - "fmt" - "os" - - "github.com/isrc-cas/gt/client/webrtc" -) - -func main() { - webrtc.SetLog(webrtc.LoggingSeverityWarning, func(severity webrtc.LoggingSeverity, message, tag string) { - fmt.Println("severity", severity.String(), "message", message, "tag", tag) - }) - go echoServer() - select {} -} - -func echoServer() { - waitICEGatheringComplete := make(chan struct{}) - var peerConnection *webrtc.PeerConnection - var err error - peerConnectionConfig := webrtc.PeerConnectionConfig{ - ICEServers: []string{"stun:stun.l.google.com:19302"}, - OnSignalingChange: func(state webrtc.SignalingState) { - fmt.Println("signaling", state.String()) - }, - OnDataChannel: func(dataChannelWithoutCallback *webrtc.DataChannelWithoutCallback) { - var dataChannel *webrtc.DataChannel - dataChannelConfig := webrtc.DataChannelConfig{ - OnStateChange: func(state webrtc.DataState) { - fmt.Println("data channel", state.String()) - }, - OnMessage: func(message []byte) { - if !dataChannel.Send(message) { - panic("data channel send failed") - } - }, - } - dataChannelWithoutCallback.SetCallback(&dataChannelConfig, &dataChannel) - }, - OnRenegotiationNeeded: func() { - }, - OnNegotiationNeeded: func() { - }, - OnICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnStandardizedICEConnectionChange: func(state webrtc.ICEConnectionState) { - }, - OnConnectionChange: func(state webrtc.PeerConnectionState) { - fmt.Println("peer connection", state.String()) - }, - OnICEGatheringChange: func(state webrtc.ICEGatheringState) { - fmt.Println("ice gathering", state.String()) - if state == webrtc.ICEGatheringStateComplete { - close(waitICEGatheringComplete) - } - }, - OnICECandidate: func(iceCandidate *webrtc.ICECandidate) { - fmt.Printf("get ice candidate:'%#v'\n", iceCandidate) - }, - OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) { - }, - } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection) - if err != nil { - panic(err) - } - - fmt.Print("please entern offer: ") - scanner := bufio.NewScanner(os.Stdin) - scanner.Scan() - offerJSON := scanner.Text() - fmt.Println(offerJSON) - var offer webrtc.SessionDescription - err = json.Unmarshal([]byte(offerJSON), &offer) - if err != nil { - panic(err) - } - err = peerConnection.SetRemoteDescription(&offer) - if err != nil { - panic(err) - } - - answer, err := peerConnection.CreateAnswer() - if err != nil { - panic(err) - } - err = peerConnection.SetLocalDescription(answer) - if err != nil { - panic(err) - } - <-waitICEGatheringComplete - answer = peerConnection.GetLocalDescription() - answerJSON, err := json.Marshal(answer) - if err != nil { - panic(err) - } - fmt.Printf("answer: '%v'\n", string(answerJSON)) -} diff --git a/client/webrtc/threadpool.cpp b/client/webrtc/threadpool.cpp new file mode 100644 index 00000000..83d55378 --- /dev/null +++ b/client/webrtc/threadpool.cpp @@ -0,0 +1,105 @@ +// Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include + +#include "threadpool.h" + +// 实现 webrtc 的线程池管理,如果每个连接都分配三个线程的话, +// 在高并发的情况下性能较差。想象一个圆盘,盘面上有很多个扇形, +// 每个扇形代表一个线程,这些扇形刚开始是白色的,表示还没有初始化。 +// 线程初始化之后扇形的颜色就变成了黑色,然后圆盘会转动起来, +// 越来越多的扇形变成了黑色。 +class ThreadPool { + public: + ThreadPool(uint32_t threadNum) { + threads = std::vector>(threadNum); + socketThreads = std::vector>(threadNum); + } + + rtc::Thread *GetThread() { + // 先用读锁的方式从线程池中获取一个线程,如果此线程没有初始化,那么就用写锁的方式初始化它 + std::shared_lock sharedLock(threadsSharedMutex); + if (!threads[threadIndex]) { + sharedLock.unlock(); + std::unique_lock uniqueLock(threadsSharedMutex); + if (!threads[threadIndex]) { + threads[threadIndex] = rtc::Thread::Create(); + threads[threadIndex]->Start(); + } + uniqueLock.unlock(); + sharedLock.lock(); + } + + // 指向下一个线程 + auto result = threads[threadIndex].get(); + threadIndex = (threadIndex + 1) % threads.size(); + + // 返回线程 + return result; + } + + rtc::Thread *GetSocketThread() { + // 先用读锁的方式从线程池中获取一个线程,如果此线程没有初始化,那么就用写锁的方式初始化它 + std::shared_lock sharedLock(threadsSharedMutex); + if (!socketThreads[threadIndex]) { + sharedLock.unlock(); + std::unique_lock uniqueLock(threadsSharedMutex); + if (!socketThreads[threadIndex]) { + socketThreads[threadIndex] = rtc::Thread::CreateWithSocketServer(); + socketThreads[threadIndex]->Start(); + } + uniqueLock.unlock(); + sharedLock.lock(); + } + + // 指向下一个线程 + auto result = socketThreads[threadIndex].get(); + threadIndex = (threadIndex + 1) % socketThreads.size(); + + // 返回线程 + return result; + } + + ~ThreadPool() { + for (auto &thread : threads) { + if (thread) { + thread->Stop(); + } + } + } + + private: + std::vector> threads; + std::vector> socketThreads; + std::shared_mutex threadsSharedMutex; + std::atomic_uint32_t threadIndex = 0; +}; + +void *NewThreadPool(uint32_t threadNum) { + return new ThreadPool(threadNum); +} + +void *GetThreadPoolThread(void *threadPool) { return ((ThreadPool *)threadPool)->GetThread(); } + +void *GetThreadPoolSocketThread(void *threadPool) { + return ((ThreadPool *)threadPool)->GetSocketThread(); +} + +void DeleteThreadPool(void *threadPool) { delete (ThreadPool *)threadPool; } diff --git a/client/webrtc/threadpool.go b/client/webrtc/threadpool.go new file mode 100644 index 00000000..cdba7055 --- /dev/null +++ b/client/webrtc/threadpool.go @@ -0,0 +1,41 @@ +// Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package webrtc + +/* +#include "threadpool.h" +*/ +import "C" +import "unsafe" + +type ThreadPool struct { + p unsafe.Pointer +} + +func NewThreadPool(threadNum uint32) *ThreadPool { + return &ThreadPool{p: C.NewThreadPool(C.uint32_t(threadNum))} +} + +func (tp *ThreadPool) GetThread() unsafe.Pointer { + return C.GetThreadPoolThread(tp.p) +} + +func (tp *ThreadPool) GetSocketThread() unsafe.Pointer { + return C.GetThreadPoolSocketThread(tp.p) +} + +func (tp *ThreadPool) Close() { + C.DeleteThreadPool(tp.p) +} diff --git a/client/webrtc/threadpool.h b/client/webrtc/threadpool.h new file mode 100644 index 00000000..b645f047 --- /dev/null +++ b/client/webrtc/threadpool.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THREADPOOL_H +#define THREADPOOL_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +void *NewThreadPool(uint32_t threadNum); +void *GetThreadPoolThread(void *threadPool); +void *GetThreadPoolSocketThread(void *threadPool); +void DeleteThreadPool(void *threadPool); + +#ifdef __cplusplus +} +#endif + +#endif /* THREADPOOL_H */ \ No newline at end of file diff --git a/predef/debug.go b/predef/debug.go index 0807a70b..7f6cde38 100644 --- a/predef/debug.go +++ b/predef/debug.go @@ -13,7 +13,6 @@ // limitations under the License. //go:build !release -// +build !release package predef @@ -27,7 +26,7 @@ import ( ) // Debug enables the logs of read and write operations -var Debug = true +var Debug = false func init() { env, ok := os.LookupEnv("DEBUG_REQ") diff --git a/predef/release.go b/predef/release.go index 1a251696..5993e9e4 100644 --- a/predef/release.go +++ b/predef/release.go @@ -13,14 +13,8 @@ // limitations under the License. //go:build release -// +build release package predef -import ( - // used for prof - _ "net/http/pprof" -) - // Debug enables the logs of read and write operations const Debug = false diff --git a/test/p2p_test.go b/test/p2p_test.go index 0b11571f..c22567dd 100644 --- a/test/p2p_test.go +++ b/test/p2p_test.go @@ -100,7 +100,7 @@ func TestP2PGetOffer(t *testing.T) { OnICECandidateError: func(addrss string, port int, url string, errorCode int, errorText string) { }, } - err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection) + err = webrtc.NewPeerConnection(&peerConnectionConfig, &peerConnection, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -469,7 +469,7 @@ func initOffer(t *testing.T, addr string) (*webrtc.PeerConnection, context.Conte }, } var err error - err = webrtc.NewPeerConnection(&config, &peerConnection) + err = webrtc.NewPeerConnection(&config, &peerConnection, nil, nil, nil) if err != nil { t.Fatal(err) }