Skip to content

Commit

Permalink
kafka/pubsub: fix LSN potentially being updated too early (#1709)
Browse files Browse the repository at this point in the history
LSN should not be updated before success confirmed,
as intermediate value may now be read before error aborts when queue is flushed

With parallelism, lsn should be updated in critical section so that if earlier invocation is lagging behind LSN doesn't skip

Also fix pubsub needing EnableMessageOrdering explicitly enabled
  • Loading branch information
serprex authored May 10, 2024
1 parent 0116f5e commit 5f904dd
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 48 deletions.
55 changes: 36 additions & 19 deletions flow/connectors/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,18 @@ func lvalueToKafkaRecord(ls *lua.LState, value lua.LValue) (*kgo.Record, error)
return kr, nil
}

type poolResult struct {
records []*kgo.Record
lsn int64
}

func (c *KafkaConnector) createPool(
ctx context.Context,
script string,
flowJobName string,
lastSeenLSN *atomic.Int64,
queueErr func(error),
) (*utils.LPool[[]*kgo.Record], error) {
produceCb := func(_ *kgo.Record, err error) {
if err != nil {
queueErr(err)
}
}

) (*utils.LPool[poolResult], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, flowJobName, s)
Expand All @@ -187,25 +187,40 @@ func (c *KafkaConnector) createPool(
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(krs []*kgo.Record) {
for _, kr := range krs {
c.client.Produce(ctx, kr, produceCb)
}, func(result poolResult) {
lenRecords := int32(len(result.records))
if lenRecords == 0 {
if lastSeenLSN != nil {
shared.AtomicInt64Max(lastSeenLSN, result.lsn)
}
} else {
recordCounter := atomic.Int32{}
recordCounter.Store(lenRecords)
for _, kr := range result.records {
c.client.Produce(ctx, kr, func(_ *kgo.Record, err error) {
if err != nil {
queueErr(err)
} else if recordCounter.Add(-1) == 0 && lastSeenLSN != nil {
shared.AtomicInt64Max(lastSeenLSN, result.lsn)
}
})
}
}
})
}

func (c *KafkaConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
numRecords := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
lastSeenLSN := atomic.Int64{}

queueCtx, queueErr := context.WithCancelCause(ctx)
pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, queueErr)
pool, err := c.createPool(queueCtx, req.Script, req.FlowJobName, &lastSeenLSN, queueErr)
if err != nil {
return nil, err
}
defer pool.Close()

lastSeenLSN := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
flushLoopDone := make(chan struct{})
go func() {
ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds())
Expand Down Expand Up @@ -244,20 +259,20 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []*kgo.Record {
pool.Run(func(ls *lua.LState) poolResult {
lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -266,7 +281,7 @@ Loop:
kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if kr != nil {
if kr.Topic == "" {
Expand All @@ -278,8 +293,10 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
return results
return poolResult{
records: results,
lsn: record.GetCheckpointID(),
}
})

case <-queueCtx.Done():
Expand Down
12 changes: 6 additions & 6 deletions flow/connectors/kafka/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (c *KafkaConnector) SyncQRepRecords(
schema := stream.Schema()

queueCtx, queueErr := context.WithCancelCause(ctx)
pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, queueErr)
pool, err := c.createPool(queueCtx, config.Script, config.FlowJobName, nil, queueErr)
if err != nil {
return 0, err
}
Expand All @@ -44,7 +44,7 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []*kgo.Record {
pool.Run(func(ls *lua.LState) poolResult {
items := model.NewRecordItems(len(qrecord))
for i, val := range qrecord {
items.AddColumn(schema.Fields[i].Name, val)
Expand All @@ -61,15 +61,15 @@ Loop:
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -78,7 +78,7 @@ Loop:
kr, err := lvalueToKafkaRecord(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if kr != nil {
if kr.Topic == "" {
Expand All @@ -89,7 +89,7 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
return results
return poolResult{records: results}
})

case <-queueCtx.Done():
Expand Down
51 changes: 37 additions & 14 deletions flow/connectors/pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ type PubSubMessage struct {
Topic string
}

type poolResult struct {
messages []PubSubMessage
lsn int64
}

type publishResult struct {
*pubsub.PublishResult
lsn int64
}

func lvalueToPubSubMessage(ls *lua.LState, value lua.LValue) (PubSubMessage, error) {
var topic string
var msg *pubsub.Message
Expand Down Expand Up @@ -125,9 +135,9 @@ func (c *PubSubConnector) createPool(
script string,
flowJobName string,
topiccache *topicCache,
publish chan<- *pubsub.PublishResult,
publish chan<- publishResult,
queueErr func(error),
) (*utils.LPool[[]PubSubMessage], error) {
) (*utils.LPool[poolResult], error) {
return utils.LuaPool(func() (*lua.LState, error) {
ls, err := utils.LoadScript(ctx, script, utils.LuaPrintFn(func(s string) {
_ = c.LogFlowInfo(ctx, flowJobName, s)
Expand All @@ -139,10 +149,14 @@ func (c *PubSubConnector) createPool(
ls.Env.RawSetString("onRecord", ls.NewFunction(utils.DefaultOnRecord))
}
return ls, nil
}, func(messages []PubSubMessage) {
for _, message := range messages {
}, func(result poolResult) {
for _, message := range result.messages {
topicClient, err := topiccache.GetOrSet(message.Topic, func() (*pubsub.Topic, error) {
topicClient := c.client.Topic(message.Topic)
if message.OrderingKey != "" {
topicClient.EnableMessageOrdering = true
}

exists, err := topicClient.Exists(ctx)
if err != nil {
return nil, fmt.Errorf("error checking if topic exists: %w", err)
Expand All @@ -160,7 +174,12 @@ func (c *PubSubConnector) createPool(
return
}

publish <- topicClient.Publish(ctx, message.Message)
publish <- publishResult{
PublishResult: topicClient.Publish(ctx, message.Message),
}
}
publish <- publishResult{
lsn: result.lsn,
}
})
}
Expand Down Expand Up @@ -216,9 +235,10 @@ func (tc *topicCache) GetOrSet(topic string, f func() (*pubsub.Topic, error)) (*

func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecordsRequest[model.RecordItems]) (*model.SyncResponse, error) {
numRecords := atomic.Int64{}
lastSeenLSN := atomic.Int64{}
tableNameRowsMapping := utils.InitialiseTableRowsMap(req.TableMappings)
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan *pubsub.PublishResult, 32)
publish := make(chan publishResult, 32)
waitChan := make(chan struct{})

queueCtx, queueErr := context.WithCancelCause(ctx)
Expand All @@ -230,15 +250,16 @@ func (c *PubSubConnector) SyncRecords(ctx context.Context, req *model.SyncRecord

go func() {
for curpub := range publish {
if _, err := curpub.Get(ctx); err != nil {
if curpub.PublishResult == nil {
shared.AtomicInt64Max(&lastSeenLSN, curpub.lsn)
} else if _, err := curpub.Get(ctx); err != nil {
queueErr(err)
break
}
}
close(waitChan)
}()

lastSeenLSN := atomic.Int64{}
flushLoopDone := make(chan struct{})
go func() {
ticker := time.NewTicker(peerdbenv.PeerDBQueueFlushTimeoutSeconds())
Expand Down Expand Up @@ -274,20 +295,20 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []PubSubMessage {
pool.Run(func(ls *lua.LState) poolResult {
lfn := ls.Env.RawGetString("onRecord")
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -296,7 +317,7 @@ Loop:
msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if msg.Message != nil {
if msg.Topic == "" {
Expand All @@ -308,8 +329,10 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
shared.AtomicInt64Max(&lastSeenLSN, record.GetCheckpointID())
return results
return poolResult{
messages: results,
lsn: record.GetCheckpointID(),
}
})

case <-queueCtx.Done():
Expand Down
20 changes: 11 additions & 9 deletions flow/connectors/pubsub/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (c *PubSubConnector) SyncQRepRecords(
numRecords := atomic.Int64{}
schema := stream.Schema()
topiccache := topicCache{cache: make(map[string]*pubsub.Topic)}
publish := make(chan *pubsub.PublishResult, 32)
publish := make(chan publishResult, 32)
waitChan := make(chan struct{})

queueCtx, queueErr := context.WithCancelCause(ctx)
Expand All @@ -40,9 +40,11 @@ func (c *PubSubConnector) SyncQRepRecords(

go func() {
for curpub := range publish {
if _, err := curpub.Get(ctx); err != nil {
queueErr(err)
break
if curpub.PublishResult != nil {
if _, err := curpub.Get(ctx); err != nil {
queueErr(err)
break
}
}
}
close(waitChan)
Expand All @@ -57,7 +59,7 @@ Loop:
break Loop
}

pool.Run(func(ls *lua.LState) []PubSubMessage {
pool.Run(func(ls *lua.LState) poolResult {
items := model.NewRecordItems(len(qrecord))
for i, val := range qrecord {
items.AddColumn(schema.Fields[i].Name, val)
Expand All @@ -74,15 +76,15 @@ Loop:
fn, ok := lfn.(*lua.LFunction)
if !ok {
queueErr(fmt.Errorf("script should define `onRecord` as function, not %s", lfn))
return nil
return poolResult{}
}

ls.Push(fn)
ls.Push(pua.LuaRecord.New(ls, record))
err := ls.PCall(1, -1, nil)
if err != nil {
queueErr(fmt.Errorf("script failed: %w", err))
return nil
return poolResult{}
}

args := ls.GetTop()
Expand All @@ -91,7 +93,7 @@ Loop:
msg, err := lvalueToPubSubMessage(ls, ls.Get(i-args))
if err != nil {
queueErr(err)
return nil
return poolResult{}
}
if msg.Message != nil {
if msg.Topic == "" {
Expand All @@ -102,7 +104,7 @@ Loop:
}
ls.SetTop(0)
numRecords.Add(1)
return results
return poolResult{messages: results}
})

case <-queueCtx.Done():
Expand Down

0 comments on commit 5f904dd

Please sign in to comment.