kgo group: switch memberID and generation to atomics
Previously these required a mutex on write and read because of the rare
(and mostly erroneous) chance that a person commits during a rebalance.

That requirement makes the upcoming transactional commit work harder and is
overkill: switching to atomics does not change correctness at all, but it
lets us worry about deadlocks a bit less.
twmb committed Oct 21, 2023
1 parent 6a961da commit 39e28c0
Showing 2 changed files with 71 additions and 47 deletions.
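For context, here is a minimal, self-contained sketch of the pattern this commit adopts: the member ID and generation live together behind a single atomic.Value, so a commit path can read a consistent pair without taking the group mutex. The type and method names mirror the diff below; the main function and its values are illustrative only.

package main

import (
    "fmt"
    "sync/atomic"
)

// groupMemberGen keeps the member ID and generation together so that one
// atomic load always returns a consistent pair.
type groupMemberGen struct {
    v atomic.Value // *groupMemberGenT
}

type groupMemberGenT struct {
    memberID   string
    generation int32
}

// load returns the current member ID and generation, or ("", -1) if the
// client has not yet joined a group.
func (g *groupMemberGen) load() (memberID string, generation int32) {
    v := g.v.Load()
    if v == nil {
        return "", -1
    }
    t := v.(*groupMemberGenT)
    return t.memberID, t.generation
}

// store atomically replaces both fields; in the real code this is called
// from the join and sync loop.
func (g *groupMemberGen) store(memberID string, generation int32) {
    g.v.Store(&groupMemberGenT{memberID, generation})
}

func main() {
    var g groupMemberGen
    fmt.Println(g.load()) // "" -1: not in a group yet
    g.store("member-abc", 7)
    fmt.Println(g.load()) // member-abc 7: readable lock-free while committing
}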
113 changes: 68 additions & 45 deletions pkg/kgo/consumer_group.go
@@ -126,14 +126,12 @@ type groupConsumer struct {
uncommitted uncommitted

// memberID and generation are written to in the join and sync loop,
// and mostly read within that loop. The reason these two are under the
// mutex is because they are read during commits, which can happen at
// any arbitrary moment. It is **recommended** to be done within the
// context of a group session, but (a) users may have some unique use
// cases, and (b) the onRevoke hook may take longer than a user
// and mostly read within that loop. This can be read during commits,
// which can happen at any time. It is **recommended** to be done within
// the context of a group session, but (a) users may have some unique
// use cases, and (b) the onRevoke hook may take longer than a user
// expects, which would rotate a session.
memberID string
generation int32
memberGen groupMemberGen

// commitCancel and commitDone are set under mu before firing off an
// async commit request. If another commit happens, it cancels the
@@ -155,6 +153,42 @@ type groupConsumer struct {
leaveErr error // set before left is closed
}

type groupMemberGen struct {
v atomic.Value // *groupMemberGenT
}

type groupMemberGenT struct {
memberID string
generation int32
}

func (g *groupMemberGen) memberID() string {
memberID, _ := g.load()
return memberID
}

func (g *groupMemberGen) generation() int32 {
_, generation := g.load()
return generation
}

func (g *groupMemberGen) load() (memberID string, generation int32) {
v := g.v.Load()
if v == nil {
return "", -1
}
t := v.(*groupMemberGenT)
return t.memberID, t.generation
}

func (g *groupMemberGen) store(memberID string, generation int32) {
g.v.Store(&groupMemberGenT{memberID, generation})
}

func (g *groupMemberGen) storeMember(memberID string) {
g.store(memberID, g.generation())
}

// LeaveGroup leaves a group. Close automatically leaves the group, so this is
// only necessary to call if you plan to leave the group but continue to use
// the client. If a rebalance is in progress, this function waits for the
@@ -235,12 +269,7 @@ func (cl *Client) GroupMetadata() (string, int32) {
if g == nil {
return "", -1
}
g.mu.Lock()
defer g.mu.Unlock()
if g.memberID == "" {
return "", -1
}
return g.memberID, g.generation
return g.memberGen.load()
}

func (c *consumer) initGroup() {
@@ -488,17 +517,18 @@ func (g *groupConsumer) leave(ctx context.Context) {
return
}

memberID := g.memberGen.memberID()
g.cfg.logger.Log(LogLevelInfo, "leaving group",
"group", g.cfg.group,
"member_id", g.memberID, // lock not needed now since nothing can change it (manageDone)
"member_id", memberID,
)
// If we error when leaving, there is not much
// we can do. We may as well just return.
req := kmsg.NewPtrLeaveGroupRequest()
req.Group = g.cfg.group
req.MemberID = g.memberID
req.MemberID = memberID
member := kmsg.NewLeaveGroupRequestMember()
member.MemberID = g.memberID
member.MemberID = memberID
member.Reason = kmsg.StringPtr("client leaving group per normal operation")
req.Members = append(req.Members, member)

@@ -940,8 +970,9 @@ func (g *groupConsumer) heartbeat(fetchErrCh <-chan error, s *assignRevokeSessio
g.cfg.logger.Log(LogLevelDebug, "heartbeating", "group", g.cfg.group)
req := kmsg.NewPtrHeartbeatRequest()
req.Group = g.cfg.group
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID
var resp *kmsg.HeartbeatResponse
if resp, err = req.RequestWith(g.ctx, g.cl); err == nil {
@@ -1075,7 +1106,7 @@ start:
joinReq.SessionTimeoutMillis = int32(g.cfg.sessionTimeout.Milliseconds())
joinReq.RebalanceTimeoutMillis = int32(g.cfg.rebalanceTimeout.Milliseconds())
joinReq.ProtocolType = g.cfg.protocol
joinReq.MemberID = g.memberID
joinReq.MemberID = g.memberGen.memberID()
joinReq.InstanceID = g.cfg.instanceID
joinReq.Protocols = g.joinGroupProtocols()
if joinWhy != "" {
@@ -1120,8 +1151,9 @@ start:

syncReq := kmsg.NewPtrSyncGroupRequest()
syncReq.Group = g.cfg.group
syncReq.Generation = g.generation
syncReq.MemberID = g.memberID
memberID, generation := g.memberGen.load()
syncReq.Generation = generation
syncReq.MemberID = memberID
syncReq.InstanceID = g.cfg.instanceID
syncReq.ProtocolType = &g.cfg.protocol
syncReq.Protocol = &protocol
@@ -1168,7 +1200,7 @@ start:
// and must trigger a rebalance.
if plan != nil && joinResp.SkipAssignment {
for _, assign := range plan {
if assign.MemberID == g.memberID {
if assign.MemberID == memberID {
if !bytes.Equal(assign.MemberAssignment, syncResp.MemberAssignment) {
g.rejoin("instance group leader restarted and was reassigned old plan, our topic interests changed and we must rejoin to force a rebalance")
}
@@ -1184,27 +1216,17 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
if err = kerr.ErrorForCode(resp.ErrorCode); err != nil {
switch err {
case kerr.MemberIDRequired:
g.mu.Lock()
g.memberID = resp.MemberID // KIP-394
g.mu.Unlock()
g.memberGen.storeMember(resp.MemberID) // KIP-394
g.cfg.logger.Log(LogLevelInfo, "join returned MemberIDRequired, rejoining with response's MemberID", "group", g.cfg.group, "member_id", resp.MemberID)
return true, "", nil, nil
case kerr.UnknownMemberID:
g.mu.Lock()
g.memberID = ""
g.mu.Unlock()
g.memberGen.storeMember("")
g.cfg.logger.Log(LogLevelInfo, "join returned UnknownMemberID, rejoining without a member id", "group", g.cfg.group)
return true, "", nil, nil
}
return // Request retries as necessary, so this must be a failure
}

// Concurrent committing, while erroneous to do at the moment, could
// race with this function. We need to lock setting these two fields.
g.mu.Lock()
g.memberID = resp.MemberID
g.generation = resp.Generation
g.mu.Unlock()
g.memberGen.store(resp.MemberID, resp.Generation)

if resp.Protocol != nil {
protocol = *resp.Protocol
@@ -1252,9 +1274,9 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
g.leader.Store(true)
g.cfg.logger.Log(LogLevelInfo, "joined, balancing group",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"balance_protocol", protocol,
"leader", true,
)
@@ -1263,18 +1285,18 @@ func (g *groupConsumer) handleJoinResp(resp *kmsg.JoinGroupResponse) (restart bo
g.leader.Store(true)
g.cfg.logger.Log(LogLevelInfo, "joined as leader but unable to balance group due to KIP-345 limitations",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"balance_protocol", protocol,
"leader", true,
)
} else {
g.cfg.logger.Log(LogLevelInfo, "joined",
"group", g.cfg.group,
"member_id", g.memberID,
"member_id", resp.MemberID,
"instance_id", strptr{g.cfg.instanceID},
"generation", g.generation,
"generation", resp.Generation,
"leader", false,
)
}
@@ -1427,7 +1449,6 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
for t, ps := range g.lastAssigned {
lastDup[t] = append([]int32(nil), ps...) // deep copy to allow modifications
}
gen := g.generation

g.mu.Unlock()

@@ -1436,6 +1457,7 @@ func (g *groupConsumer) joinGroupProtocols() []kmsg.JoinGroupRequestProtocol {
sort.Slice(partitions, func(i, j int) bool { return partitions[i] < partitions[j] }) // same for partitions
}

gen := g.memberGen.generation()
var protos []kmsg.JoinGroupRequestProtocol
for _, balancer := range g.cfg.balancers {
proto := kmsg.NewJoinGroupRequestProtocol()
@@ -1931,7 +1953,7 @@ func (g *groupConsumer) updateCommitted(
g.mu.Lock()
defer g.mu.Unlock()

if req.Generation != g.generation {
if req.Generation != g.memberGen.generation() {
return
}
if g.uncommitted == nil {
@@ -2764,8 +2786,9 @@ func (g *groupConsumer) commit(

req := kmsg.NewPtrOffsetCommitRequest()
req.Group = g.cfg.group
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID

if ctx.Done() != nil {
5 changes: 3 additions & 2 deletions pkg/kgo/txn.go
@@ -1180,8 +1180,9 @@ func (g *groupConsumer) commitTxn(
req.Group = g.cfg.group
req.ProducerID = id
req.ProducerEpoch = epoch
req.Generation = g.generation
req.MemberID = g.memberID
memberID, generation := g.memberGen.load()
req.Generation = generation
req.MemberID = memberID
req.InstanceID = g.cfg.instanceID

if ctx.Done() != nil {
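Tying the diff back to the public API: Client.GroupMetadata (changed in the consumer_group.go hunk above) now reduces to a single atomic load, so it can be called from any goroutine, even while a commit is in flight, without touching the group mutex. The sketch below is a hedged usage example: the broker address, group, and topic are placeholders, and the surrounding option and method names are written from memory of the kgo API rather than taken from this commit.

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/twmb/franz-go/pkg/kgo"
)

func main() {
    cl, err := kgo.NewClient(
        kgo.SeedBrokers("localhost:9092"),  // placeholder broker
        kgo.ConsumerGroup("example-group"), // placeholder group
        kgo.ConsumeTopics("example-topic"), // placeholder topic
    )
    if err != nil {
        panic(err)
    }
    defer cl.Close()

    // GroupMetadata is now a lock-free read of the memberGen atomic, so
    // polling it from a side goroutine is safe even during rebalances and
    // concurrent commits.
    go func() {
        for range time.Tick(5 * time.Second) {
            member, gen := cl.GroupMetadata()
            fmt.Printf("member=%q generation=%d\n", member, gen)
        }
    }()

    ctx := context.Background()
    for {
        fetches := cl.PollFetches(ctx)
        if fetches.IsClientClosed() {
            return
        }
        // Commits read the member ID and generation through the same atomic
        // load; no group mutex is needed on this path anymore.
        if err := cl.CommitUncommittedOffsets(ctx); err != nil {
            fmt.Println("commit error:", err)
        }
    }
}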
