Skip to content

Commit

Permalink
Merge pull request #1161 from nats-io/fix_sub_req_failed_on_client_close
Browse files Browse the repository at this point in the history
[FIXED] Fail faster subscription requests for invalid clients
  • Loading branch information
kozlovic authored Feb 25, 2021
2 parents 6405a5e + 7b73297 commit 5e54581
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 31 deletions.
55 changes: 34 additions & 21 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2778,7 +2778,10 @@ func (s *StanServer) initInternalSubs(createPub bool) error {
}
}
// Receive subscription requests from clients.
s.subSub, err = s.createSub(s.info.Subscribe, s.processSubscriptionRequest, "subscribe request")
// Don't make this subscription unlimited because we would rather drop
// subscriptions requests than adding to an already possibly overloaded
// server.
s.subSub, err = s.createSubWithUnlimited(s.info.Subscribe, s.processSubscriptionRequest, "subscribe request", false)
if err != nil {
return err
}
Expand Down Expand Up @@ -2861,11 +2864,17 @@ func (s *StanServer) unsubscribeInternalSubs() {
}

func (s *StanServer) createSub(subj string, f nats.MsgHandler, errTxt string) (*nats.Subscription, error) {
return s.createSubWithUnlimited(subj, f, errTxt, true)
}

func (s *StanServer) createSubWithUnlimited(subj string, f nats.MsgHandler, errTxt string, setUnlimited bool) (*nats.Subscription, error) {
sub, err := s.nc.Subscribe(subj, f)
if err != nil {
return nil, fmt.Errorf("could not subscribe to %s subject: %v", errTxt, err)
}
sub.SetPendingLimits(-1, -1)
if setUnlimited {
sub.SetPendingLimits(-1, -1)
}
return sub, nil
}

Expand Down Expand Up @@ -3223,17 +3232,15 @@ func (s *StanServer) checkClientHealth(clientID string) {
// close the client (connection). This locks the
// client object internally so unlock here.
client.Unlock()
s.barrier(func() {
// If clustered, thread operations through Raft.
if s.isClustered {
if err := s.replicateConnClose(&pb.CloseRequest{ClientID: clientID}); err != nil {
s.log.Errorf("[Client:%s] Failed to replicate disconnect on heartbeat expiration: %v",
clientID, err)
}
} else {
s.closeClient(clientID)
// If clustered, thread operations through Raft.
if s.isClustered {
if err := s.replicateConnClose(&pb.CloseRequest{ClientID: clientID}, false); err != nil {
s.log.Errorf("[Client:%s] Failed to replicate disconnect on heartbeat expiration: %v",
clientID, err)
}
})
} else {
s.closeClient(clientID)
}
return
}
} else {
Expand Down Expand Up @@ -3305,7 +3312,7 @@ func (s *StanServer) processCloseRequest(m *nats.Msg) {
var err error
// If clustered, thread operations through Raft.
if s.isClustered {
err = s.replicateConnClose(req)
err = s.replicateConnClose(req, true)
} else {
err = s.closeClient(req.ClientID)
}
Expand All @@ -3316,12 +3323,14 @@ func (s *StanServer) processCloseRequest(m *nats.Msg) {
})
}

func (s *StanServer) replicateConnClose(req *pb.CloseRequest) error {
// Go through the list of subscriptions and possibly
// flush the pending replication of sent/ack.
subs := s.clients.getSubs(req.ClientID)
for _, sub := range subs {
s.endSubSentAndAckReplication(sub, false)
func (s *StanServer) replicateConnClose(req *pb.CloseRequest, flushSubAcks bool) error {
if flushSubAcks {
// Go through the list of subscriptions and possibly
// flush the pending replication of sent/ack.
subs := s.clients.getSubs(req.ClientID)
for _, sub := range subs {
s.endSubSentAndAckReplication(sub, false)
}
}

op := &spb.RaftOperation{
Expand Down Expand Up @@ -4873,7 +4882,7 @@ func (s *StanServer) replicateSub(c *channel, sr *pb.SubscriptionRequest, ackInb
func (s *StanServer) addSubscription(ss *subStore, sub *subState) error {
// Store in client
if !s.clients.addSub(sub.ClientID, sub) {
return fmt.Errorf("can't find clientID: %v", sub.ClientID)
return ErrUnknownClient
}
// Store this subscription in subStore
if err := ss.Store(sub); err != nil {
Expand All @@ -4887,7 +4896,7 @@ func (s *StanServer) addSubscription(ss *subStore, sub *subState) error {
func (s *StanServer) updateDurable(ss *subStore, sub *subState, clientID string) error {
// Store in the client
if !s.clients.addSub(clientID, sub) {
return fmt.Errorf("can't find clientID: %v", clientID)
return ErrUnknownClient
}
// Update this subscription in the store
sub.Lock()
Expand Down Expand Up @@ -5151,6 +5160,10 @@ func (s *StanServer) processSubscriptionRequest(m *nats.Msg) {
s.sendSubscriptionResponseErr(m.Reply, ErrInvalidSubReq)
return
}
} else if !s.clients.isValid(sr.ClientID, nil) {
// If client is not known, fail the request.
s.sendSubscriptionResponseErr(m.Reply, ErrUnknownClient)
return
}

var (
Expand Down
16 changes: 9 additions & 7 deletions server/server_req_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2016-2019 The NATS Authors
// Copyright 2016-2021 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestInvalidSubRequest(t *testing.T) {
defer nc.Close()

// This test is very dependent on the validity tests performed
// in StanServer.processSubscriptionRequest(). Any cahnge there
// in StanServer.processSubscriptionRequest(). Any change there
// may require changes here.

// Create empty request
Expand Down Expand Up @@ -201,20 +201,22 @@ func TestInvalidSubRequest(t *testing.T) {
}

// Test Queue Group DurableName
sc := NewDefaultConnection(t)
defer sc.Close()
req.Subject = "foo"
req.QGroup = "queue"
req.DurableName = "dur:name"
if err := sendInvalidSubRequest(s, nc, req, ErrInvalidDurName); err != nil {
t.Fatalf("%v", err)
}
sc.Close()

// Reset those
req.QGroup = ""
req.DurableName = ""

// Now we should have an error that says that we can't find client ID
// (that is, client was not registered).
if err := sendInvalidSubRequest(s, nc, req, fmt.Errorf("can't find clientID: %v", clientName)); err != nil {
// Now we should have an error that says that we have an unknown client ID.
if err := sendInvalidSubRequest(s, nc, req, ErrUnknownClient); err != nil {
t.Fatalf("%v", err)
}

Expand All @@ -227,7 +229,7 @@ func TestInvalidSubRequest(t *testing.T) {
}

// Create a durable
sc := NewDefaultConnection(t)
sc = NewDefaultConnection(t)
defer sc.Close()
dur, err := sc.Subscribe("foo", func(_ *stan.Msg) {}, stan.DurableName("dur"))
if err != nil {
Expand All @@ -245,7 +247,7 @@ func TestInvalidSubRequest(t *testing.T) {
req.ClientID = clientName
req.Subject = "foo"
req.DurableName = "dur"
if err := sendInvalidSubRequest(s, nc, req, fmt.Errorf("can't find clientID: %v", clientName)); err != nil {
if err := sendInvalidSubRequest(s, nc, req, ErrUnknownClient); err != nil {
t.Fatalf("%v", err)
}
}
Expand Down
78 changes: 77 additions & 1 deletion server/server_sub_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2016-2019 The NATS Authors
// Copyright 2016-2021 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand All @@ -24,6 +24,7 @@ import (
"github.com/nats-io/nats-streaming-server/spb"
"github.com/nats-io/nats-streaming-server/stores"
"github.com/nats-io/nats.go"
"github.com/nats-io/nuid"
"github.com/nats-io/stan.go"
"github.com/nats-io/stan.go/pb"
)
Expand Down Expand Up @@ -1341,3 +1342,78 @@ func TestSubCloseByInbox(t *testing.T) {
t.Fatalf("Should not be any subscription, got %v", subs)
}
}

func TestSubRequestsFailedIfClientClosed(t *testing.T) {
sOpts := GetDefaultOptions()
sOpts.ID = clusterName
sOpts.ClientHBInterval = 15 * time.Millisecond
sOpts.ClientHBTimeout = 15 * time.Millisecond
sOpts.ClientHBFailCount = 1
sOpts.StoreLimits.SubStoreLimits.MaxSubscriptions = 0
nOpts := DefaultNatsServerOptions
s := runServerWithOpts(t, sOpts, &nOpts)
defer s.Shutdown()

// Use a bare NATS connection to send incorrect requests
nc, err := nats.Connect(nats.DefaultURL)
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
defer nc.Close()

sub, err := nc.SubscribeSync("subreply")
if err != nil {
t.Fatalf("Error on subscribe: %v", err)
}
sub.SetPendingLimits(-1, -1)

req := &pb.ConnectRequest{ClientID: clientName, HeartbeatInbox: "hbInbox"}
b, _ := req.Marshal()
resp, err := nc.Request(s.info.Discovery, b, time.Second)
if err != nil {
t.Fatalf("Unexpected error on publishing request: %v", err)
}
r := &pb.ConnectResponse{}
err = r.Unmarshal(resp.Data)
if err != nil {
t.Fatalf("Unexpected response object: %v", err)
}
if r.Error != "" {
t.Fatalf("Unexpected error: %v", r.Error)
}

s.channels.Lock()

for i := 0; i < 1000; i++ {
req := &pb.SubscriptionRequest{
ClientID: clientName,
Subject: "foo",
Inbox: nuid.Next(),
MaxInFlight: 1,
AckWaitInSecs: 30,
}
b, _ := req.Marshal()
if err := nc.PublishRequest(s.info.Subscribe, sub.Subject, b); err != nil {
t.Fatalf("Error on request: %v", err)
}
}

s.channels.Unlock()

for {
msg, err := sub.NextMsg(250 * time.Millisecond)
if err != nil {
break
}

rply := &pb.SubscriptionResponse{}
rply.Unmarshal(msg.Data)
if rply.Error == "" {
continue
}
if rply.Error != ErrUnknownClient.Error() {
t.Fatalf("Expected error %q, got %q", ErrUnknownClient, rply.Error)
}
break
}
}
18 changes: 16 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2016-2019 The NATS Authors
// Copyright 2016-2021 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -1630,6 +1630,9 @@ func TestFileSliceMaxBytesCmdLine(t *testing.T) {
}

func TestInternalSubsLimits(t *testing.T) {
setPartitionsVarsForTest()
defer resetDefaultPartitionsVars()

cleanupDatastore(t)
defer cleanupDatastore(t)
cleanupRaftLog(t)
Expand Down Expand Up @@ -1659,13 +1662,20 @@ func TestInternalSubsLimits(t *testing.T) {
s := runServerWithOpts(t, o, nil)
defer s.Shutdown()

switch test.name {
case "clustered":
getLeader(t, time.Second, s)
case "ft":
getFTActiveServer(t, s)
default:
}

s.mu.Lock()
defer s.mu.Unlock()

subs := []*nats.Subscription{
s.connectSub,
s.pubSub,
s.subSub,
s.subUnsubSub,
s.subCloseSub,
s.closeSub,
Expand All @@ -1683,6 +1693,10 @@ func TestInternalSubsLimits(t *testing.T) {
sub.Subject, err, count, sz)
}
}
// The subscription on "client subscription requests" should not be unlimited.
if count, sz, err := s.subSub.PendingLimits(); err != nil || count == -1 || sz == -1 {
t.Fatalf("The subSub subscription should not be unlimited: err=%v count=%v sz=%v", err, count, sz)
}
})
}
}

0 comments on commit 5e54581

Please sign in to comment.