diff --git a/server/clustering.go b/server/clustering.go index 1f2dec07..c6e98ce5 100644 --- a/server/clustering.go +++ b/server/clustering.go @@ -523,7 +523,7 @@ func (r *raftFSM) Apply(l *raft.Log) interface{} { msg.Sequence, msg.Subject, err)) } } - return nil + return c.store.Msgs.Flush() case spb.RaftOperation_Connect: // Client connection create replication. return s.processConnect(op.ClientConnect.Request, op.ClientConnect.Refresh) diff --git a/server/clustering_test.go b/server/clustering_test.go index 34c257c7..3ce67c03 100644 --- a/server/clustering_test.go +++ b/server/clustering_test.go @@ -15,6 +15,7 @@ package server import ( "bytes" + "database/sql" "encoding/json" "fmt" "io/ioutil" @@ -64,10 +65,10 @@ func cleanupRaftLog(t *testing.T) { func shutdownAndCleanupState(t *testing.T, s *StanServer, nodeID string) { t.Helper() s.Shutdown() + os.RemoveAll(filepath.Join(defaultRaftLog, nodeID)) switch persistentStoreType { case stores.TypeFile: os.RemoveAll(filepath.Join(defaultDataStore, nodeID)) - os.RemoveAll(filepath.Join(defaultRaftLog, nodeID)) case stores.TypeSQL: test.CleanupSQLDatastore(t, testSQLDriver, testSQLSource+"_"+nodeID) default: @@ -6383,3 +6384,69 @@ func TestClusteringRestoreSnapshotMsgsBailIfNoLeader(t *testing.T) { } } } + +func TestClusteringSQLMsgStoreFlushed(t *testing.T) { + if !doSQL { + t.SkipNow() + } + + cleanupDatastore(t) + defer cleanupDatastore(t) + cleanupRaftLog(t) + defer cleanupRaftLog(t) + + // For this test, use a central NATS server. + ns := natsdTest.RunDefaultServer() + defer ns.Shutdown() + + // Configure first server + s1sOpts := getTestDefaultOptsForClustering("a", true) + s1 := runServerWithOpts(t, s1sOpts, nil) + defer s1.Shutdown() + + // Configure second server. 
+ s2sOpts := getTestDefaultOptsForClustering("b", false) + s2 := runServerWithOpts(t, s2sOpts, nil) + defer s2.Shutdown() + + getLeader(t, 10*time.Second, s1, s2) + + sc := NewDefaultConnection(t) + defer sc.Close() + + ch := make(chan bool, 1) + count := 0 + // Use less than SQLStore's sqlMsgCacheLimit + total := 500 + ah := func(guid string, err error) { + count++ + if count == total { + ch <- true + } + } + for i := 0; i < total; i++ { + if _, err := sc.PublishAsync("foo", []byte("hello"), ah); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Did not get all our acks") + } + + db, err := sql.Open(testSQLDriver, testSQLSource+"_b") + if err != nil { + t.Fatalf("Error opening db: %v", err) + } + defer db.Close() + r := db.QueryRow("SELECT COUNT(seq) FROM Messages") + count = 0 + if err := r.Scan(&count); err != nil { + t.Fatalf("Error on scan: %v", err) + } + if count == 0 { + t.Fatalf("Expected some messages, got none") + } +} diff --git a/server/server_delivery_test.go b/server/server_delivery_test.go index ed7ad955..0cfe6a19 100644 --- a/server/server_delivery_test.go +++ b/server/server_delivery_test.go @@ -257,6 +257,9 @@ func TestDeliveryWithGapsInSequence(t *testing.T) { } func TestPersistentStoreSQLSubsPendingRows(t *testing.T) { + if !doSQL { + t.SkipNow() + } source := testSQLSource if persistentStoreType != stores.TypeSQL { // If not running tests with `-persistent_store sql`, diff --git a/server/server_test.go b/server/server_test.go index b9fbf730..97dae4ee 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -81,13 +81,13 @@ var ( testSQLSourceAdmin = testDefaultMySQLSourceAdmin testSQLDatabaseName = testDefaultDatabaseName testDBSuffixes = []string{"", "_a", "_b", "_c"} + doSQL = false ) func TestMain(m *testing.M) { var ( bst string pst string - doSQL bool sqlCreateDb bool sqlDeleteDb bool ) diff --git a/stores/sqlstore.go b/stores/sqlstore.go index 
b4ce934e..62c4272b 100644 --- a/stores/sqlstore.go +++ b/stores/sqlstore.go @@ -181,6 +181,10 @@ const ( // Number of missed update interval after which the lock is assumed // lost and another instance can update it. sqlDefaultLockLostCount = 3 + + // Limit of number of messages in the cache before message store + // is automatically flushed on a Store() call. + sqlDefaultMsgCacheLimit = 1024 ) // These are initialized based on the constants that have reasonable values. @@ -195,6 +199,7 @@ var ( sqlLockUpdateInterval = sqlDefaultLockUpdateInterval sqlLockLostCount = sqlDefaultLockLostCount sqlNoPanic = false // Used in tests to avoid go-routine to panic + sqlMsgCacheLimit = sqlDefaultMsgCacheLimit ) // SQLStoreOptions are used to configure the SQL Store. @@ -337,10 +342,11 @@ type SQLMsgStore struct { } type sqlMsgsCache struct { - msgs map[uint64]*sqlCachedMsg - head *sqlCachedMsg - tail *sqlCachedMsg - free *sqlCachedMsg + msgs map[uint64]*sqlCachedMsg + head *sqlCachedMsg + tail *sqlCachedMsg + free *sqlCachedMsg + count int } type sqlCachedMsg struct { @@ -1330,6 +1336,7 @@ func (mc *sqlMsgsCache) add(msg *pb.MsgProto, data []byte) { mc.tail.next = cachedMsg } mc.tail = cachedMsg + mc.count++ } func (mc *sqlMsgsCache) transferToFreeList() { @@ -1339,6 +1346,7 @@ func (mc *sqlMsgsCache) transferToFreeList() { } mc.head = nil mc.tail = nil + mc.count = 0 } func (mc *sqlMsgsCache) pop() *sqlCachedMsg { @@ -1349,6 +1357,7 @@ func (mc *sqlMsgsCache) pop() *sqlCachedMsg { if mc.head == nil { mc.tail = nil } + mc.count-- } return cm } @@ -1370,6 +1379,11 @@ func (ms *SQLMsgStore) Store(m *pb.MsgProto) (uint64, error) { useCache := !ms.sqlStore.opts.NoCaching if useCache { + if ms.writeCache.count >= sqlMsgCacheLimit { + if err := ms.flush(); err != nil { + return 0, err + } + } ms.writeCache.add(m, msgBytes) } else { if _, err := ms.sqlStore.preparedStmts[sqlStoreMsg].Exec(ms.channelID, seq, m.Timestamp, dataLen, msgBytes); err != nil { diff --git 
a/stores/sqlstore_test.go b/stores/sqlstore_test.go index ea640083..ed3cd9bf 100644 --- a/stores/sqlstore_test.go +++ b/stores/sqlstore_test.go @@ -1819,3 +1819,40 @@ func TestSQLRecoverLastSeqAfterMessagesExpired(t *testing.T) { } s.Close() } + +func TestSQLMsgCacheAutoFlush(t *testing.T) { + if !doSQL { + t.SkipNow() + } + + sqlMsgCacheLimit = 100 + defer func() { sqlMsgCacheLimit = sqlDefaultMsgCacheLimit }() + + cleanupSQLDatastore(t) + defer cleanupSQLDatastore(t) + + // Create a store with caching enabled (which is default, but invoke option here) + s, err := NewSQLStore(testLogger, testSQLDriver, testSQLSource, nil, SQLNoCaching(false)) + if err != nil { + t.Fatalf("Error creating store: %v", err) + } + defer s.Close() + + cs := storeCreateChannel(t, s, "foo") + total := sqlMsgCacheLimit + 10 + payload := make([]byte, 100) + for i := 0; i < total; i++ { + storeMsg(t, cs, "foo", uint64(i+1), payload) + } + // Check that we have started to write messages into the DB. + db := getDBConnection(t) + defer db.Close() + r := db.QueryRow("SELECT COUNT(seq) FROM Messages") + count := 0 + if err := r.Scan(&count); err != nil { + t.Fatalf("Error on scan: %v", err) + } + if count == 0 { + t.Fatalf("Expected some messages, got none") + } +}