From 5f904dd4b089885d593ee35c727ad113189a7d96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Fri, 10 May 2024 15:02:09 +0000 Subject: [PATCH] kafka/pubsub: fix LSN potentially being updated too early (#1709) LSN should not be updated before success is confirmed, as an intermediate value may now be read before an error aborts when the queue is flushed. With parallelism, LSN should be updated in a critical section so that if an earlier invocation is lagging behind, LSN doesn't skip. Also fix pubsub needing EnableMessageOrdering explicitly enabled. --- flow/connectors/kafka/kafka.go | 55 +++++++++++++++++++++----------- flow/connectors/kafka/qrep.go | 12 +++---- flow/connectors/pubsub/pubsub.go | 51 +++++++++++++++++++++-------- flow/connectors/pubsub/qrep.go | 20 ++++++------ 4 files changed, 90 insertions(+), 48 deletions(-) diff --git a/flow/connectors/kafka/kafka.go b/flow/connectors/kafka/kafka.go index c58da5e50f..ec7decc50c 100644 --- a/flow/connectors/kafka/kafka.go +++ b/flow/connectors/kafka/kafka.go @@ -164,18 +164,18 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error) return kr, nil } +type poolResult struct { + records []*kgo.Record + lsn int64 +} + func (c *KafkaConnector) createPool( ctx context.Context, script string, flowJobName string, + lastSeenLSN *atomic.Int64, queueErr func(error), -) (*utils.LPool[[]*kgo.Record], error) { - produceCb := func(_ *kgo.Record, err error) { - if err != nil { - queueErr(err) - } - } - +) (*utils.LPool[poolResult], error) { return utils.LuaPool(func() (*lua.LState, error) { ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) { _ = c.LogFlowInfo(ctx, flowJobName, s) @@ -187,25 +187,40 @@ func (c *KafkaConnector) createPool( ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) } return ls, nil - }, func(krs []*kgo.Record) { - for _, kr := range krs { - c.client.Produce(ctx, kr, produceCb) + }, func(result poolResult) { + lenRecords := 
int32(len(result.records)) + if lenRecords == 0 { + if lastSeenLSN != nil { + shared.AtomicInt64Max(lastSeenLSN, result.lsn) + } + } else { + recordCounter := atomic.Int32{} + recordCounter.Store(lenRecords) + for _, kr := range result.records { + c.client.Produce(ctx, kr, func(_ *kgo.Record, err error) { + if err != nil { + queueErr(err) + } else if recordCounter.Add(-1) == 0 && lastSeenLSN != nil { + shared.AtomicInt64Max(lastSeenLSN, result.lsn) + } + }) + } } }) } func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) { numRecords := atomic.Int64{} - tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) + lastSeenLSN := atomic.Int64{} queueCtx, queueErr := context.WithCancelCause(ctx) - pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, queueErr) + pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, &lastSeenLSN, queueErr) if err != nil { return nil, err } defer pool.Close() - lastSeenLSN := atomic.Int64{} + tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) flushLoopDone := make(chan struct{}) go func() { ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds()) @@ -244,12 +259,12 @@ Loop: break Loop } - pool.Run(func(ls *lua.LState) []*kgo.Record { + pool.Run(func(ls *lua.LState) poolResult { lfn := ls.Env.RawGetString("onRecord") fn, ok := lfn.(*lua.LFunction) if !ok { queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) - return nil + return poolResult{} } ls.Push(fn) @@ -257,7 +272,7 @@ Loop: err := ls.PCall(1, -1, nil) if err != nil { queueErr(fmt.Errorf("script failed: %w", err)) - return nil + return poolResult{} } args := ls.GetTop() @@ -266,7 +281,7 @@ Loop: kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args)) if err != nil { queueErr(err) - return nil + return poolResult{} } if kr != nil { if kr.Topic == "" { @@ -278,8 +293,10 @@ Loop: } ls.SetTop(0) numRecords.Add(1) 
- shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) - return results + return poolResult{ + records: results, + lsn: record.GetCheckpointID(), + } }) case <-queueCtx.Done(): diff --git a/flow/connectors/kafka/qrep.go b/flow/connectors/kafka/qrep.go index a856ad1ccf..0c472ae0ae 100644 --- a/flow/connectors/kafka/qrep.go +++ b/flow/connectors/kafka/qrep.go @@ -29,7 +29,7 @@ func (c *KafkaConnector) SyncQRepRecords( schema := stream.Schema() queueCtx, queueErr := context.WithCancelCause(ctx) - pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, queueErr) + pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, nil, queueErr) if err != nil { return 0, err } @@ -44,7 +44,7 @@ Loop: break Loop } - pool.Run(func(ls *lua.LState) []*kgo.Record { + pool.Run(func(ls *lua.LState) poolResult { items := model.NewRecordItems(len(qrecord)) for i, val := range qrecord { items.AddColumn(schema.Fields[i].Name, val) @@ -61,7 +61,7 @@ Loop: fn, ok := lfn.(*lua.LFunction) if !ok { queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) - return nil + return poolResult{} } ls.Push(fn) @@ -69,7 +69,7 @@ Loop: err := ls.PCall(1, -1, nil) if err != nil { queueErr(fmt.Errorf("script failed: %w", err)) - return nil + return poolResult{} } args := ls.GetTop() @@ -78,7 +78,7 @@ Loop: kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args)) if err != nil { queueErr(err) - return nil + return poolResult{} } if kr != nil { if kr.Topic == "" { @@ -89,7 +89,7 @@ Loop: } ls.SetTop(0) numRecords.Add(1) - return results + return poolResult{records: results} }) case <-queueCtx.Done(): diff --git a/flow/connectors/pubsub/pubsub.go b/flow/connectors/pubsub/pubsub.go index 0a8709b3b2..1ad3931432 100644 --- a/flow/connectors/pubsub/pubsub.go +++ b/flow/connectors/pubsub/pubsub.go @@ -76,6 +76,16 @@ type PubSubMessage struct { Topic string } +type poolResult struct { + messages []PubSubMessage + lsn int64 +} + +type publishResult struct { + 
*pubsub.PublishResult + lsn int64 +} + func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, error) { var topic string var msg *pubsub.Message @@ -125,9 +135,9 @@ func (c *PubSubConnector) createPool( script string, flowJobName string, topiccache *topicCache, - publish chan<- *pubsub.PublishResult, + publish chan<- publishResult, queueErr func(error), -) (*utils.LPool[[]PubSubMessage], error) { +) (*utils.LPool[poolResult], error) { return utils.LuaPool(func() (*lua.LState, error) { ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) { _ = c.LogFlowInfo(ctx, flowJobName, s) @@ -139,10 +149,14 @@ func (c *PubSubConnector) createPool( ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord)) } return ls, nil - }, func(messages []PubSubMessage) { - for _, message := range messages { + }, func(result poolResult) { + for _, message := range result.messages { topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) { topicClient := c.client.Topic(message.Topic) + if message.OrderingKey != "" { + topicClient.EnableMessageOrdering = true + } + exists, err := topicClient.Exists(ctx) if err != nil { return nil, fmt.Errorf("error checking if topic exists: %w", err) @@ -160,7 +174,12 @@ func (c *PubSubConnector) createPool( return } - publish <- topicClient.Publish(ctx, message.Message) + publish <- publishResult{ + PublishResult: topicClient.Publish(ctx, message.Message), + } + } + publish <- publishResult{ + lsn: result.lsn, } }) } @@ -216,9 +235,10 @@ func (tc *topicCache) GetOrSet(topic string, f func() (*pubsub.Topic, error)) (* func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) { numRecords := atomic.Int64{} + lastSeenLSN := atomic.Int64{} tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings) topiccache := topicCache{cache: make(map[string]*pubsub.Topic)} - publish := make(chan 
*pubsub.PublishResult, 32) + publish := make(chan publishResult, 32) waitChan := make(chan struct{}) queueCtx, queueErr := context.WithCancelCause(ctx) @@ -230,7 +250,9 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord go func() { for curpub := range publish { - if _, err := curpub.Get(ctx); err != nil { + if curpub.PublishResult == nil { + shared.AtomicInt64Max(&lastSeenLSN, curpub.lsn) + } else if _, err := curpub.Get(ctx); err != nil { queueErr(err) break } @@ -238,7 +260,6 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord close(waitChan) }() - lastSeenLSN := atomic.Int64{} flushLoopDone := make(chan struct{}) go func() { ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds()) @@ -274,12 +295,12 @@ Loop: break Loop } - pool.Run(func(ls *lua.LState) []PubSubMessage { + pool.Run(func(ls *lua.LState) poolResult { lfn := ls.Env.RawGetString("onRecord") fn, ok := lfn.(*lua.LFunction) if !ok { queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) - return nil + return poolResult{} } ls.Push(fn) @@ -287,7 +308,7 @@ Loop: err := ls.PCall(1, -1, nil) if err != nil { queueErr(fmt.Errorf("script failed: %w", err)) - return nil + return poolResult{} } args := ls.GetTop() @@ -296,7 +317,7 @@ Loop: msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args)) if err != nil { queueErr(err) - return nil + return poolResult{} } if msg.Message != nil { if msg.Topic == "" { @@ -308,8 +329,10 @@ Loop: } ls.SetTop(0) numRecords.Add(1) - shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID()) - return results + return poolResult{ + messages: results, + lsn: record.GetCheckpointID(), + } }) case <-queueCtx.Done(): diff --git a/flow/connectors/pubsub/qrep.go b/flow/connectors/pubsub/qrep.go index c1f21edc4a..139b3ca210 100644 --- a/flow/connectors/pubsub/qrep.go +++ b/flow/connectors/pubsub/qrep.go @@ -28,7 +28,7 @@ func (c *PubSubConnector) SyncQRepRecords( numRecords := 
atomic.Int64{} schema := stream.Schema() topiccache := topicCache{cache: make(map[string]*pubsub.Topic)} - publish := make(chan *pubsub.PublishResult, 32) + publish := make(chan publishResult, 32) waitChan := make(chan struct{}) queueCtx, queueErr := context.WithCancelCause(ctx) @@ -40,9 +40,11 @@ func (c *PubSubConnector) SyncQRepRecords( go func() { for curpub := range publish { - if _, err := curpub.Get(ctx); err != nil { - queueErr(err) - break + if curpub.PublishResult != nil { + if _, err := curpub.Get(ctx); err != nil { + queueErr(err) + break + } } } close(waitChan) @@ -57,7 +59,7 @@ Loop: break Loop } - pool.Run(func(ls *lua.LState) []PubSubMessage { + pool.Run(func(ls *lua.LState) poolResult { items := model.NewRecordItems(len(qrecord)) for i, val := range qrecord { items.AddColumn(schema.Fields[i].Name, val) @@ -74,7 +76,7 @@ Loop: fn, ok := lfn.(*lua.LFunction) if !ok { queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn)) - return nil + return poolResult{} } ls.Push(fn) @@ -82,7 +84,7 @@ Loop: err := ls.PCall(1, -1, nil) if err != nil { queueErr(fmt.Errorf("script failed: %w", err)) - return nil + return poolResult{} } args := ls.GetTop() @@ -91,7 +93,7 @@ Loop: msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args)) if err != nil { queueErr(err) - return nil + return poolResult{} } if msg.Message != nil { if msg.Topic == "" { @@ -102,7 +104,7 @@ Loop: } ls.SetTop(0) numRecords.Add(1) - return results + return poolResult{messages: results} }) case <-queueCtx.Done():