From 4ffb815692e2b4e88ce28412b5227db8bd237fa9 Mon Sep 17 00:00:00 2001 From: Salim Afiune Maya Date: Sat, 14 Dec 2024 22:46:08 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=99=EF=B8=8F=20=20return=20`pool.Result`?= =?UTF-8?q?=20as=20a=20combined=20struct?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Salim Afiune Maya --- explorer/scan/discovery.go | 4 +- internal/workerpool/collector.go | 38 +++++++------- internal/workerpool/pool.go | 64 ++++++++++++++++-------- internal/workerpool/pool_test.go | 50 +++++++++++------- internal/workerpool/worker.go | 15 +----- providers/github/resources/github_org.go | 9 ++-- 6 files changed, 106 insertions(+), 74 deletions(-) diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go index d645b7c09..994b9f402 100644 --- a/explorer/scan/discovery.go +++ b/explorer/scan/discovery.go @@ -188,7 +188,9 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i pool.Wait() // Get all assets with runtimes from the pool - for _, assetWithRuntime := range pool.GetResults() { + for _, result := range pool.GetResults() { + assetWithRuntime := result.Value + // If asset is nil, then we observed a duplicate asset with a // runtime that already exists. if assetWithRuntime == nil { diff --git a/internal/workerpool/collector.go b/internal/workerpool/collector.go index 2d105501b..bb33bb836 100644 --- a/internal/workerpool/collector.go +++ b/internal/workerpool/collector.go @@ -9,13 +9,11 @@ import ( ) type collector[R any] struct { - resultsCh <-chan R - results []R + resultsCh <-chan Result[R] + results []Result[R] read sync.Mutex - errorsCh <-chan error - errors []error - + // The total number of requests read. requestsRead int64 } @@ -27,29 +25,35 @@ func (c *collector[R]) start() { c.read.Lock() c.results = append(c.results, result) c.read.Unlock() - - case err := <-c.errorsCh: - c.read.Lock() - c.errors = append(c.errors, err) - c.read.Unlock() } atomic.AddInt64(&c.requestsRead, 1) } }() } -func (c *collector[R]) GetResults() []R { + +func (c *collector[R]) RequestsRead() int64 { + return atomic.LoadInt64(&c.requestsRead) +} + +func (c *collector[R]) GetResults() []Result[R] { c.read.Lock() defer c.read.Unlock() return c.results } -func (c *collector[R]) GetErrors() []error { - c.read.Lock() - defer c.read.Unlock() - return c.errors +func (c *collector[R]) GetValues() (slice []R) { + results := c.GetResults() + for i := range results { + slice = append(slice, results[i].Value) + } + return } -func (c *collector[R]) RequestsRead() int64 { - return atomic.LoadInt64(&c.requestsRead) +func (c *collector[R]) GetErrors() (slice []error) { + results := c.GetResults() + for i := range results { + slice = append(slice, results[i].Error) + } + return } diff --git a/internal/workerpool/pool.go b/internal/workerpool/pool.go index d59d5b009..1ad4afa86 100644 --- a/internal/workerpool/pool.go +++ b/internal/workerpool/pool.go @@ -7,25 +7,40 @@ import ( "sync" "sync/atomic" "time" - - "github.com/cockroachdb/errors" ) +// Represent the tasks that can be sent to the pool. type Task[R any] func() (result R, err error) +// The result generated from a task. +type Result[R any] struct { + Value R + Error error +} + // Pool is a generic pool of workers. type Pool[R any] struct { - queueCh chan Task[R] - resultsCh chan R - errorsCh chan error + // The queue where tasks are submitted. + queueCh chan Task[R] + // Where workers send the results after a task is executed, + // the collector then reads them and aggregate them. + resultsCh chan Result[R] + + // The total number of requests sent. requestsSent int64 - once sync.Once - workers []*worker[R] + // Number of workers to spawn. workerCount int + // The list of workers that are listening to the queue. + workers []*worker[R] + + // A single collector to aggregate results. collector[R] + + // used to protect starting the pool multiple times + once sync.Once } // New initializes a new Pool with the provided number of workers. The pool is generic and can @@ -37,14 +52,12 @@ type Pool[R any] struct { // return 42, nil // } func New[R any](count int) *Pool[R] { - resultsCh := make(chan R) - errorsCh := make(chan error) + resultsCh := make(chan Result[R]) return &Pool[R]{ queueCh: make(chan Task[R]), resultsCh: resultsCh, - errorsCh: errorsCh, workerCount: count, - collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh}, + collector: collector[R]{resultsCh: resultsCh}, } } @@ -56,7 +69,7 @@ func New[R any](count int) *Pool[R] { func (p *Pool[R]) Start() { p.once.Do(func() { for i := 0; i < p.workerCount; i++ { - w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh} + w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh} w.start() p.workers = append(p.workers, &w) } @@ -67,22 +80,33 @@ func (p *Pool[R]) Start() { // Submit sends a task to the workers func (p *Pool[R]) Submit(t Task[R]) { - p.queueCh <- t - atomic.AddInt64(&p.requestsSent, 1) -} - -// GetErrors returns any error from a processed task -func (p *Pool[R]) GetErrors() error { - return errors.Join(p.collector.GetErrors()...) + if t != nil { + p.queueCh <- t + atomic.AddInt64(&p.requestsSent, 1) + } } // GetResults returns the tasks results. // // It is recommended to call `Wait()` before reading the results. -func (p *Pool[R]) GetResults() []R { +func (p *Pool[R]) GetResults() []Result[R] { return p.collector.GetResults() } +// GetValues returns only the values of the pool results +// +// It is recommended to call `Wait()` before reading the results. +func (p *Pool[R]) GetValues() []R { + return p.collector.GetValues() +} + +// GetErrors returns only the errors of the pool results +// +// It is recommended to call `Wait()` before reading the results. +func (p *Pool[R]) GettErrors() []error { + return p.collector.GetErrors() +} + // Close waits for workers and collector to process all the requests, and then closes // the task queue channel. After closing the pool, calling `Submit()` will panic. func (p *Pool[R]) Close() { diff --git a/internal/workerpool/pool_test.go b/internal/workerpool/pool_test.go index 3b3946df1..222dad570 100644 --- a/internal/workerpool/pool_test.go +++ b/internal/workerpool/pool_test.go @@ -35,11 +35,10 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) { // should have one result results := pool.GetResults() if assert.Len(t, results, 1) { - assert.Equal(t, 42, results[0]) + assert.Equal(t, 42, results[0].Value) + // without errors + assert.NoError(t, results[0].Error) } - - // no errors - assert.Nil(t, pool.GetErrors()) } func TestPoolHandleErrors(t *testing.T) { @@ -53,12 +52,12 @@ func TestPoolHandleErrors(t *testing.T) { } pool.Submit(task) - // Wait for error collector to process + // Wait for collector to process the results pool.Wait() - err := pool.GetErrors() - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "task error") + errs := pool.GetErrors() + if assert.Len(t, errs, 1) { + assert.Equal(t, errs[0].Error(), "task error") } } @@ -86,12 +85,26 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) { // Wait for error collector to process pool.Wait() - results := pool.GetResults() - assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results) - err := pool.GetErrors() - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "task error") - } + // Access results together + assert.ElementsMatch(t, + []workerpool.Result[*test]{ + {&test{1}, nil}, + {&test{2}, nil}, + {&test{3}, nil}, + {nil, errors.New("task error")}, + }, + pool.GetResults(), + ) + + // You can also access values and errors directly + assert.ElementsMatch(t, + []*test{nil, &test{1}, &test{2}, &test{3}}, + pool.GetValues(), + ) + assert.ElementsMatch(t, + []error{nil, nil, errors.New("task error"), nil}, + pool.GetErrors(), + ) } func TestPoolHandlesNilTasks(t *testing.T) { @@ -104,8 +117,8 @@ func TestPoolHandlesNilTasks(t *testing.T) { pool.Wait() - err := pool.GetErrors() - assert.NoError(t, err) + assert.Empty(t, pool.GetErrors()) + assert.Empty(t, pool.GetValues()) } func TestPoolProcessing(t *testing.T) { @@ -126,9 +139,8 @@ func TestPoolProcessing(t *testing.T) { // wait pool.Wait() - // read results - result := pool.GetResults() - assert.Equal(t, []int{10}, result) + // read values + assert.Equal(t, []int{10}, pool.GetValues()) // should not longer be processing assert.False(t, pool.Processing()) diff --git a/internal/workerpool/worker.go b/internal/workerpool/worker.go index 77b5c81f1..31257353c 100644 --- a/internal/workerpool/worker.go +++ b/internal/workerpool/worker.go @@ -6,25 +6,14 @@ package workerpool type worker[R any] struct { id int queueCh <-chan Task[R] - resultsCh chan<- R - errorsCh chan<- error + resultsCh chan<- Result[R] } func (w *worker[R]) start() { go func() { for task := range w.queueCh { - if task == nil { - // let the collector know we processed the request - w.errorsCh <- nil - continue - } - data, err := task() - if err != nil { - w.errorsCh <- err - } else { - w.resultsCh <- data - } + w.resultsCh <- Result[R]{data, err} } }() } diff --git a/providers/github/resources/github_org.go b/providers/github/resources/github_org.go index ef39f9715..0876e4a71 100644 --- a/providers/github/resources/github_org.go +++ b/providers/github/resources/github_org.go @@ -4,12 +4,12 @@ package resources import ( - "errors" "slices" "strconv" "strings" "time" + "github.com/cockroachdb/errors" "github.com/google/go-github/v67/github" "github.com/rs/zerolog/log" "go.mondoo.com/cnquery/v11/internal/workerpool" @@ -287,7 +287,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { for { // exit as soon as we collect all repositories - reposLen := len(slices.Concat(workerPool.GetResults()...)) + reposLen := len(slices.Concat(workerPool.GetValues()...)) if reposLen >= int(repoCount) { break } @@ -303,7 +303,8 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { listOpts.Page++ // check if any request failed - if err := workerPool.GetErrors(); err != nil { + if errs := workerPool.GetErrors(); len(errs) != 0 { + err := errors.Join(errs...) if strings.Contains(err.Error(), "404") { return nil, nil } @@ -316,7 +317,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) { } res := []interface{}{} - for _, repos := range workerPool.GetResults() { + for _, repos := range workerPool.GetValues() { for i := range repos { repo := repos[i]