diff --git a/pkg/networking/session.go b/pkg/networking/session.go index b0f760555..d0b1d8057 100644 --- a/pkg/networking/session.go +++ b/pkg/networking/session.go @@ -12,6 +12,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" "github.com/wavesplatform/gowaves/pkg/execution" @@ -36,10 +37,9 @@ type Session struct { sendLock sync.Mutex // Guards the sendCh. sendCh chan *sendPacket // sendCh is used to send data to the connection. - establishedLock sync.Mutex // Guards the established field. - established bool // Indicates that incoming Handshake was successfully accepted. - shutdownLock sync.Mutex // Guards the shutdown field. - shutdown bool // shutdown is used to safely close the Session. + receiving atomic.Bool // Indicates that receiveLoop already running. + established atomic.Bool // Indicates that incoming Handshake was successfully accepted. + shutdown sync.Once // shutdown is used to safely close the Session. } // NewSession is used to construct a new session. @@ -109,29 +109,24 @@ func (s *Session) RemoteAddr() net.Addr { // Close is used to close the session. It is safe to call Close multiple times from different goroutines, // subsequent calls do nothing. func (s *Session) Close() error { - s.shutdownLock.Lock() - defer s.shutdownLock.Unlock() - - if s.shutdown { - return nil // Fast path - session already closed. - } - s.shutdown = true - - s.logger.Debug("Closing session") - clErr := s.conn.Close() // Close the underlying connection. - if clErr != nil { - s.logger.Warn("Failed to close underlying connection", "error", clErr) - } - s.logger.Debug("Underlying connection closed") + var err error + s.shutdown.Do(func() { + s.logger.Debug("Closing session") + clErr := s.conn.Close() // Close the underlying connection. + if clErr != nil { + s.logger.Warn("Failed to close underlying connection", "error", clErr) + } + s.logger.Debug("Underlying connection closed") - s.cancel() // Cancel the underlying context to interrupt the loops. + s.cancel() // Cancel the underlying context to interrupt the loops. - s.logger.Debug("Waiting for loops to finish") - err := s.g.Wait() // Wait for loops to finish. + s.logger.Debug("Waiting for loops to finish") + err = s.g.Wait() // Wait for loops to finish. - err = errors.Join(err, clErr) // Combine loops finalization errors with connection close error. + err = errors.Join(err, clErr) // Combine loops finalization errors with connection close error. - s.logger.Debug("Session closed", "error", err) + s.logger.Debug("Session closed", "error", err) + }) return err } @@ -253,9 +248,9 @@ func (s *Session) sendLoop() error { // receiveLoop continues to receive data until a fatal error is encountered or underlying connection is closed. // Receive loop works after handshake and accepts only length-prepended messages. func (s *Session) receiveLoop() error { - s.establishedLock.Lock() // Prevents from running multiple receiveLoops. - defer s.establishedLock.Unlock() - + if !s.receiving.CompareAndSwap(false, true) { + return nil // Prevent running multiple receive loops. + } for { if err := s.receive(); err != nil { if errors.Is(err, ErrConnectionClosedOnRead) { @@ -268,7 +263,7 @@ func (s *Session) receiveLoop() error { } func (s *Session) receive() error { - if s.established { + if s.established.Load() { hdr := s.config.protocol.EmptyHeader() return s.readMessage(hdr) } @@ -295,7 +290,7 @@ func (s *Session) readHandshake() error { return ErrUnacceptableHandshake } // Handshake is acceptable, we can switch the session into established state. - s.established = true + s.established.Store(true) s.config.handler.OnHandshake(s, hs) return nil }