Skip to content

Commit

Permalink
Merge pull request #1205 from nats-io/cluster_raft_tport_shutdown
Browse files Browse the repository at this point in the history
Ensure RAFT connection are closed if transport is closed
  • Loading branch information
kozlovic authored Jul 21, 2021
2 parents da8d394 + 2acea31 commit 81a1982
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
18 changes: 18 additions & 0 deletions server/raft_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ const (
natsLogAppName = "raft-nats"
)

var errTransportShutdown = errors.New("raft-nats: transport is being shutdown")

type natsRaftConnCreator func(name string) (*nats.Conn, error)

// natsAddr implements the net.Addr interface. An address for the NATS
Expand Down Expand Up @@ -305,6 +307,7 @@ type natsStreamLayer struct {
logger hclog.Logger
conns map[*natsConn]struct{}
mu sync.Mutex
closed bool
// This is the timeout we will use for flush and dial (request timeout),
// not the timeout that RAFT will use to call SetDeadline.
dfTimeout time.Duration
Expand Down Expand Up @@ -417,6 +420,11 @@ func (n *natsStreamLayer) Dial(address raft.ServerAddress, timeout time.Duration
}

n.mu.Lock()
if n.closed {
n.mu.Unlock()
peerConn.Close()
return nil, errTransportShutdown
}
n.conns[peerConn] = struct{}{}
n.mu.Unlock()
return peerConn, nil
Expand Down Expand Up @@ -486,6 +494,11 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) {
continue
}
n.mu.Lock()
if n.closed {
n.mu.Unlock()
peerConn.Close()
return nil, errTransportShutdown
}
n.conns[peerConn] = struct{}{}
n.mu.Unlock()
return peerConn, nil
Expand All @@ -494,6 +507,11 @@ func (n *natsStreamLayer) Accept() (net.Conn, error) {

func (n *natsStreamLayer) Close() error {
n.mu.Lock()
if n.closed {
n.mu.Unlock()
return nil
}
n.closed = true
nc := n.conn
// Do not set nc.conn to nil since it is accessed in some functions
// without the stream layer lock
Expand Down
86 changes: 86 additions & 0 deletions server/raft_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"io"
"net"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -924,3 +925,88 @@ func TestRAFTTransportConnReader(t *testing.T) {
t.Fatal("Accept() did not exit")
}
}

func TestRAFTTransportDialAcceptCloseConnOnTransportClosed(t *testing.T) {
s := runRaftTportServer()
defer s.Shutdown()

nc1 := newNatsConnection(t)
defer nc1.Close()
stream1, err := newNATSStreamLayer("a", nc1, newTestLogger(t), 2*time.Second, nil)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}
defer stream1.Close()

nc2 := newNatsConnection(t)
defer nc2.Close()
stream2, err := newNATSStreamLayer("b", nc2, newTestLogger(t), 2*time.Second, nil)
if err != nil {
t.Fatalf("Error creating stream: %v", err)
}
defer stream2.Close()

accepted := make(chan *natsConn, 101)
go func() {
for {
c, err := stream2.Accept()
if err != nil {
accepted <- nil
return
}
accepted <- c.(*natsConn)
}
}()

ch := make(chan bool)
dialed := make(chan net.Conn, 101)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
c, err := stream1.Dial("b", 250*time.Millisecond)
if err != nil {
return
}
dialed <- c
select {
case <-ch:
return
default:
}
}
}()
time.Sleep(50 * time.Millisecond)
stream1.Close()
stream2.Close()
close(ch)
wg.Wait()

stream1.mu.Lock()
l1 := len(stream1.conns)
stream1.mu.Unlock()
stream2.mu.Lock()
l2 := len(stream2.conns)
stream2.mu.Unlock()

for i := 0; i < 100; i++ {
select {
case c := <-dialed:
if c != nil {
c.Close()
}
default:
}
select {
case c := <-accepted:
if c != nil {
c.Close()
}
default:
}
}
if l1 > 0 || l2 > 0 {
t.Fatalf("Connections were added after streams were closed: %v/%v", l1, l2)
}
}

0 comments on commit 81a1982

Please sign in to comment.