diff --git a/pkg/kgo/consumer.go b/pkg/kgo/consumer.go
index 475270ed..9fdf9bde 100644
--- a/pkg/kgo/consumer.go
+++ b/pkg/kgo/consumer.go
@@ -433,6 +433,8 @@ func (cl *Client) PollRecords(ctx context.Context, maxPollRecords int) Fetches {
 		}()
 	}
 
+	paused := c.loadPaused()
+
 	// A group can grab the consumer lock then the group mu and
 	// assign partitions. The group mu is grabbed to update its
 	// uncommitted map. Assigning partitions clears sources ready
@@ -451,13 +453,13 @@ func (cl *Client) PollRecords(ctx context.Context, maxPollRecords int) Fetches {
 	c.sourcesReadyMu.Lock()
 	if maxPollRecords < 0 {
 		for _, ready := range c.sourcesReadyForDraining {
-			fetches = append(fetches, ready.takeBuffered())
+			fetches = append(fetches, ready.takeBuffered(paused))
 		}
 		c.sourcesReadyForDraining = nil
 	} else {
 		for len(c.sourcesReadyForDraining) > 0 && maxPollRecords > 0 {
 			source := c.sourcesReadyForDraining[0]
-			fetch, taken, drained := source.takeNBuffered(maxPollRecords)
+			fetch, taken, drained := source.takeNBuffered(paused, maxPollRecords)
 			if drained {
 				c.sourcesReadyForDraining = c.sourcesReadyForDraining[1:]
 			}
@@ -555,9 +557,7 @@ func (cl *Client) UpdateFetchMaxBytes(maxBytes, maxPartBytes int32) {
 // PauseFetchTopics sets the client to no longer fetch the given topics and
 // returns all currently paused topics. Paused topics persist until resumed.
 // You can call this function with no topics to simply receive the list of
-// currently paused topics. Pausing topics drops everything currently buffered
-// and kills any in flight fetch requests to ensure nothing that is paused
-// can be returned anymore from polling.
+// currently paused topics.
 //
 // Pausing topics is independent from pausing individual partitions with the
 // PauseFetchPartitions method. If you pause partitions for a topic with
@@ -569,15 +569,8 @@ func (cl *Client) PauseFetchTopics(topics ...string) []string {
 	if len(topics) == 0 {
 		return c.loadPaused().pausedTopics()
 	}
-
 	c.pausedMu.Lock()
 	defer c.pausedMu.Unlock()
-	defer func() {
-		c.mu.Lock()
-		defer c.mu.Unlock()
-		c.assignPartitions(nil, assignBumpSession, nil, fmt.Sprintf("pausing fetch topics %v", topics))
-	}()
-
 	paused := c.clonePaused()
 	paused.addTopics(topics...)
 	c.storePaused(paused)
@@ -587,9 +580,7 @@
 // PauseFetchPartitions sets the client to no longer fetch the given partitions
 // and returns all currently paused partitions. Paused partitions persist until
 // resumed. You can call this function with no partitions to simply receive the
-// list of currently paused partitions. Pausing partitions drops everything
-// currently buffered and kills any in flight fetch requests to ensure nothing
-// that is paused can be returned anymore from polling.
+// list of currently paused partitions.
 //
 // Pausing individual partitions is independent from pausing topics with the
 // PauseFetchTopics method. If you pause partitions for a topic with
@@ -601,15 +592,8 @@ func (cl *Client) PauseFetchPartitions(topicPartitions map[string][]int32) map[s
 	if len(topicPartitions) == 0 {
 		return c.loadPaused().pausedPartitions()
 	}
-
 	c.pausedMu.Lock()
 	defer c.pausedMu.Unlock()
-	defer func() {
-		c.mu.Lock()
-		defer c.mu.Unlock()
-		c.assignPartitions(nil, assignBumpSession, nil, fmt.Sprintf("pausing fetch partitions %v", topicPartitions))
-	}()
-
 	paused := c.clonePaused()
 	paused.addPartitions(topicPartitions)
 	c.storePaused(paused)
@@ -884,10 +868,6 @@ const (
 	// The counterpart to assignInvalidateMatching, assignSetMatching
 	// resets all matching partitions to the specified offset / epoch.
 	assignSetMatching
-
-	// For pausing, we want to drop anything inflight. We start a new
-	// session with the old tps.
-	assignBumpSession
 )
 
 func (h assignHow) String() string {
@@ -902,8 +882,6 @@ func (h assignHow) String() string {
 		return "unassigning and purging any partition matching the input topics"
 	case assignSetMatching:
 		return "reassigning any currently assigned matching partition to the input"
-	case assignBumpSession:
-		return "bumping internal consumer session to drop anything currently in flight"
 	}
 	return ""
 }
@@ -984,8 +962,6 @@ func (c *consumer) assignPartitions(assignments map[string]map[int32]Offset, how
 		// if we had no session before, which is why we need to pass in
 		// our topicPartitions.
 		session = c.guardSessionChange(tps)
-	} else if how == assignBumpSession {
-		loadOffsets, tps = c.stopSession()
 	} else {
 		loadOffsets, _ = c.stopSession()
 
@@ -1032,7 +1008,7 @@
 	// assignment went straight to listing / epoch loading, and
 	// that list/epoch never finished.
 	switch how {
-	case assignWithoutInvalidating, assignBumpSession:
+	case assignWithoutInvalidating:
 		// Nothing to do -- this is handled above.
 	case assignInvalidateAll:
 		loadOffsets = listOrEpochLoads{}
diff --git a/pkg/kgo/consumer_direct_test.go b/pkg/kgo/consumer_direct_test.go
index 548ab575..ac385e8e 100644
--- a/pkg/kgo/consumer_direct_test.go
+++ b/pkg/kgo/consumer_direct_test.go
@@ -307,8 +307,13 @@ func TestPauseIssue489(t *testing.T) {
 	}
 	cl.PauseFetchPartitions(map[string][]int32{t1: {0}})
 	sawZero, sawOne = false, false
-	for i := 0; i < 5; i++ {
-		fs := cl.PollFetches(ctx)
+	for i := 0; i < 10; i++ {
+		var fs Fetches
+		if i < 5 {
+			fs = cl.PollFetches(ctx)
+		} else {
+			fs = cl.PollRecords(ctx, 2)
+		}
 		fs.EachRecord(func(r *Record) {
 			sawZero = sawZero || r.Partition == 0
 			sawOne = sawOne || r.Partition == 1
diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go
index 226ef458..427f2091 100644
--- a/pkg/kgo/source.go
+++ b/pkg/kgo/source.go
@@ -344,8 +344,71 @@ func (s *source) hook(f *Fetch, buffered, polled bool) {
 }
 
 // takeBuffered drains a buffered fetch and updates offsets.
-func (s *source) takeBuffered() Fetch {
-	return s.takeBufferedFn(true, usedOffsets.finishUsingAllWithSet)
+func (s *source) takeBuffered(paused pausedTopics) Fetch {
+	if len(paused) == 0 {
+		return s.takeBufferedFn(true, usedOffsets.finishUsingAllWithSet)
+	}
+	var strip map[string]map[int32]struct{}
+	f := s.takeBufferedFn(true, func(os usedOffsets) {
+		for t, ps := range os {
+			// If the entire topic is paused, we allowUsable all
+			// and strip the topic entirely.
+			if paused.has(t, -1) {
+				for _, o := range ps {
+					o.from.allowUsable()
+				}
+				if strip == nil {
+					strip = make(map[string]map[int32]struct{})
+				}
+				strip[t] = nil // initialize key, for existence-but-len-0 check below
+				continue
+			}
+			// If any partition is paused, we allowUsable it and
+			// strip it from the topic. The logic is a bit weird
+			// to try to avoid initialization (alloc) if possible.
+			var stript map[int32]struct{}
+			for _, o := range ps {
+				if paused.has(o.from.topic, o.from.partition) {
+					o.from.allowUsable()
+					if stript == nil {
+						stript = make(map[int32]struct{})
+					}
+					stript[o.from.partition] = struct{}{}
+					continue
+				}
+				o.from.setOffset(o.cursorOffset)
+				o.from.allowUsable()
+			}
+			if stript != nil {
+				if strip == nil {
+					strip = make(map[string]map[int32]struct{})
+				}
+				strip[t] = stript
+			}
+		}
+	})
+	if strip != nil {
+		keep := f.Topics[:0]
+		for _, t := range f.Topics {
+			stript, ok := strip[t.Topic]
+			if ok {
+				if len(stript) == 0 {
+					continue // stripping this entire topic
+				}
+				keepp := t.Partitions[:0]
+				for _, p := range t.Partitions {
+					if _, ok := stript[p.Partition]; ok {
+						continue
+					}
+					keepp = append(keepp, p)
+				}
+				t.Partitions = keepp
+			}
+			keep = append(keep, t)
+		}
+		f.Topics = keep
+	}
+	return f
 }
 
 func (s *source) discardBuffered() {
@@ -359,7 +422,7 @@
 //
 // This returns the number of records taken and whether the source has been
 // completely drained.
-func (s *source) takeNBuffered(n int) (Fetch, int, bool) {
+func (s *source) takeNBuffered(paused pausedTopics, n int) (Fetch, int, bool) {
 	var r Fetch
 	var taken int
 
@@ -368,6 +431,17 @@
 	for len(bf.Topics) > 0 && n > 0 {
 		t := &bf.Topics[0]
 
+		// If the topic is outright paused, we allowUsable all
+		// partitions in the topic and skip the topic entirely.
+		if paused.has(t.Topic, -1) {
+			bf.Topics = bf.Topics[1:]
+			for _, pCursor := range b.usedOffsets[t.Topic] {
+				pCursor.from.allowUsable()
+			}
+			delete(b.usedOffsets, t.Topic)
+			continue
+		}
+
 		r.Topics = append(r.Topics, *t)
 		rt := &r.Topics[len(r.Topics)-1]
 		rt.Partitions = nil
@@ -377,6 +451,17 @@
 		for len(t.Partitions) > 0 && n > 0 {
 			p := &t.Partitions[0]
 
+			if paused.has(t.Topic, p.Partition) {
+				t.Partitions = t.Partitions[1:]
+				pCursor := tCursors[p.Partition]
+				pCursor.from.allowUsable()
+				delete(tCursors, p.Partition)
+				if len(tCursors) == 0 {
+					delete(b.usedOffsets, t.Topic)
+				}
+				continue
+			}
+
 			rt.Partitions = append(rt.Partitions, *p)
 			rp := &rt.Partitions[len(rt.Partitions)-1]
 
@@ -402,7 +487,7 @@
 				if len(tCursors) == 0 {
 					delete(b.usedOffsets, t.Topic)
 				}
-				break
+				continue
 			}
 
 			lastReturnedRecord := rp.Records[len(rp.Records)-1]
@@ -422,7 +507,7 @@
 
 	drained := len(bf.Topics) == 0
 	if drained {
-		s.takeBuffered()
+		s.takeBuffered(nil)
 	}
 	return r, taken, drained
 }
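
For reference, a minimal usage sketch of the post-patch behavior: pausing now only filters what polling returns, instead of bumping the consumer session and dropping buffered fetches. The broker address and the two-partition topic "foo" below are placeholders, not part of the patch.

package main

import (
	"context"
	"fmt"

	"github.com/twmb/franz-go/pkg/kgo"
)

func main() {
	// Assumptions: a reachable broker on localhost:9092 and a topic "foo"
	// with at least two partitions.
	cl, err := kgo.NewClient(
		kgo.SeedBrokers("localhost:9092"),
		kgo.ConsumeTopics("foo"),
	)
	if err != nil {
		panic(err)
	}
	defer cl.Close()

	// Pause partition 0. Buffered data for partition 1 is kept; records
	// for the paused partition are stripped at poll time.
	cl.PauseFetchPartitions(map[string][]int32{"foo": {0}})

	// Both PollFetches and PollRecords honor the paused set.
	fs := cl.PollRecords(context.Background(), 100)
	fs.EachRecord(func(r *kgo.Record) {
		fmt.Println(r.Partition, string(r.Value)) // partition 0 never shows up here
	})

	// Resume to start receiving partition 0 again.
	cl.ResumeFetchPartitions(map[string][]int32{"foo": {0}})
}

Since the stripping paths call allowUsable without setOffset for paused cursors, records stripped at poll time are not treated as consumed and should be fetched again once the partition is resumed.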