⚙️ return pool.Result as a combined struct
Signed-off-by: Salim Afiune Maya <[email protected]>
afiune committed Dec 14, 2024
1 parent e3a58fd commit 4ffb815
Showing 6 changed files with 106 additions and 74 deletions.
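This commit collapses the worker pool's separate results and errors channels into a single combined Result struct, so callers read one record per submitted task instead of correlating two streams. A minimal sketch of the new consumption pattern, using only names that appear in the diff below (the surrounding program is illustrative and assumes the go.mondoo.com/cnquery/v11/internal/workerpool import):

	pool := workerpool.New[int](4)
	pool.Start()
	pool.Submit(func() (int, error) { return 42, nil })
	pool.Wait()
	for _, res := range pool.GetResults() {
		if res.Error != nil {
			// the task failed; res.Value holds the zero value
			continue
		}
		// use res.Value
		_ = res.Value
	}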
4 changes: 3 additions & 1 deletion explorer/scan/discovery.go
@@ -188,7 +188,9 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i
pool.Wait()

// Get all assets with runtimes from the pool
for _, assetWithRuntime := range pool.GetResults() {
for _, result := range pool.GetResults() {
assetWithRuntime := result.Value

// If asset is nil, then we observed a duplicate asset with a
// runtime that already exists.
if assetWithRuntime == nil {
38 changes: 21 additions & 17 deletions internal/workerpool/collector.go
@@ -9,13 +9,11 @@ import (
)

type collector[R any] struct {
resultsCh <-chan R
results []R
resultsCh <-chan Result[R]
results []Result[R]
read sync.Mutex

errorsCh <-chan error
errors []error

// The total number of requests read.
requestsRead int64
}

@@ -27,29 +25,35 @@ func (c *collector[R]) start() {
c.read.Lock()
c.results = append(c.results, result)
c.read.Unlock()

case err := <-c.errorsCh:
c.read.Lock()
c.errors = append(c.errors, err)
c.read.Unlock()
}

atomic.AddInt64(&c.requestsRead, 1)
}
}()
}
func (c *collector[R]) GetResults() []R {

func (c *collector[R]) RequestsRead() int64 {
return atomic.LoadInt64(&c.requestsRead)
}

func (c *collector[R]) GetResults() []Result[R] {
c.read.Lock()
defer c.read.Unlock()
return c.results
}

func (c *collector[R]) GetErrors() []error {
c.read.Lock()
defer c.read.Unlock()
return c.errors
func (c *collector[R]) GetValues() (slice []R) {
results := c.GetResults()
for i := range results {
slice = append(slice, results[i].Value)
}
return
}

func (c *collector[R]) RequestsRead() int64 {
return atomic.LoadInt64(&c.requestsRead)
func (c *collector[R]) GetErrors() (slice []error) {
results := c.GetResults()
for i := range results {
slice = append(slice, results[i].Error)
}
return
}
64 changes: 44 additions & 20 deletions internal/workerpool/pool.go
@@ -7,25 +7,40 @@ import (
"sync"
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

// Task represents a unit of work that can be sent to the pool.
type Task[R any] func() (result R, err error)

// The result generated from a task.
type Result[R any] struct {
Value R
Error error
}

// Pool is a generic pool of workers.
type Pool[R any] struct {
queueCh chan Task[R]
resultsCh chan R
errorsCh chan error
// The queue where tasks are submitted.
queueCh chan Task[R]

// Where workers send results after executing a task; the collector
// then reads and aggregates them.
resultsCh chan Result[R]

// The total number of requests sent.
requestsSent int64
once sync.Once

workers []*worker[R]
// Number of workers to spawn.
workerCount int

// The list of workers that are listening to the queue.
workers []*worker[R]

// A single collector to aggregate results.
collector[R]

// used to protect starting the pool multiple times
once sync.Once
}

// New initializes a new Pool with the provided number of workers. The pool is generic and can
@@ -37,14 +52,12 @@ type Pool[R any] struct {
// return 42, nil
// }
func New[R any](count int) *Pool[R] {
resultsCh := make(chan R)
errorsCh := make(chan error)
resultsCh := make(chan Result[R])
return &Pool[R]{
queueCh: make(chan Task[R]),
resultsCh: resultsCh,
errorsCh: errorsCh,
workerCount: count,
collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
collector: collector[R]{resultsCh: resultsCh},
}
}

@@ -56,7 +69,7 @@ func New[R any](count int) *Pool[R] {
func (p *Pool[R]) Start() {
p.once.Do(func() {
for i := 0; i < p.workerCount; i++ {
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh}
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh}
w.start()
p.workers = append(p.workers, &w)
}
@@ -67,22 +80,33 @@ func (p *Pool[R]) Start() {

// Submit sends a task to the workers. Nil tasks are ignored.
func (p *Pool[R]) Submit(t Task[R]) {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}

// GetErrors returns any error from a processed task
func (p *Pool[R]) GetErrors() error {
return errors.Join(p.collector.GetErrors()...)
if t != nil {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}
}

// GetResults returns the task results.
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetResults() []R {
func (p *Pool[R]) GetResults() []Result[R] {
return p.collector.GetResults()
}

// GetValues returns only the values of the pool results
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetValues() []R {
return p.collector.GetValues()
}

// GetErrors returns only the errors of the pool results
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetErrors() []error {
return p.collector.GetErrors()
}

// Close waits for workers and collector to process all the requests, and then closes
// the task queue channel. After closing the pool, calling `Submit()` will panic.
func (p *Pool[R]) Close() {
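GetValues and GetErrors are now thin projections over the same result slice, so both stay entry-for-entry aligned with GetResults: a failed task still contributes an entry to GetValues (typically the zero value) and its error to GetErrors, while a successful task contributes a nil error, as the updated tests below rely on. A hedged sketch of collapsing those per-task errors into one, similar to what providers/github/resources/github_org.go does further down with errors.Join (which discards nil entries):

	if err := errors.Join(pool.GetErrors()...); err != nil {
		// at least one task failed; err wraps every non-nil task error
	}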
50 changes: 31 additions & 19 deletions internal/workerpool/pool_test.go
@@ -35,11 +35,10 @@ func TestPoolSubmitAndRetrieveResult(t *testing.T) {
// should have one result
results := pool.GetResults()
if assert.Len(t, results, 1) {
assert.Equal(t, 42, results[0])
assert.Equal(t, 42, results[0].Value)
// without errors
assert.NoError(t, results[0].Error)
}

// no errors
assert.Nil(t, pool.GetErrors())
}

func TestPoolHandleErrors(t *testing.T) {
@@ -53,12 +52,12 @@ func TestPoolHandleErrors(t *testing.T) {
}
pool.Submit(task)

// Wait for error collector to process
// Wait for collector to process the results
pool.Wait()

err := pool.GetErrors()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "task error")
errs := pool.GetErrors()
if assert.Len(t, errs, 1) {
assert.Equal(t, "task error", errs[0].Error())
}
}

@@ -86,12 +85,26 @@ func TestPoolMultipleTasksWithErrors(t *testing.T) {
// Wait for error collector to process
pool.Wait()

results := pool.GetResults()
assert.ElementsMatch(t, []*test{&test{1}, &test{2}, &test{3}}, results)
err := pool.GetErrors()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "task error")
}
// Access results together
assert.ElementsMatch(t,
[]workerpool.Result[*test]{
{&test{1}, nil},
{&test{2}, nil},
{&test{3}, nil},
{nil, errors.New("task error")},
},
pool.GetResults(),
)

// You can also access values and errors directly
assert.ElementsMatch(t,
[]*test{nil, &test{1}, &test{2}, &test{3}},
pool.GetValues(),
)
assert.ElementsMatch(t,
[]error{nil, nil, errors.New("task error"), nil},
pool.GetErrors(),
)
}

func TestPoolHandlesNilTasks(t *testing.T) {
@@ -104,8 +117,8 @@ func TestPoolHandlesNilTasks(t *testing.T) {

pool.Wait()

err := pool.GetErrors()
assert.NoError(t, err)
assert.Empty(t, pool.GetErrors())
assert.Empty(t, pool.GetValues())
}

func TestPoolProcessing(t *testing.T) {
@@ -126,9 +139,8 @@ func TestPoolProcessing(t *testing.T) {
// wait
pool.Wait()

// read results
result := pool.GetResults()
assert.Equal(t, []int{10}, result)
// read values
assert.Equal(t, []int{10}, pool.GetValues())

// should no longer be processing
assert.False(t, pool.Processing())
15 changes: 2 additions & 13 deletions internal/workerpool/worker.go
@@ -6,25 +6,14 @@ package workerpool
type worker[R any] struct {
id int
queueCh <-chan Task[R]
resultsCh chan<- R
errorsCh chan<- error
resultsCh chan<- Result[R]
}

func (w *worker[R]) start() {
go func() {
for task := range w.queueCh {
if task == nil {
// let the collector know we processed the request
w.errorsCh <- nil
continue
}

data, err := task()
if err != nil {
w.errorsCh <- err
} else {
w.resultsCh <- data
}
w.resultsCh <- Result[R]{data, err}
}
}()
}
9 changes: 5 additions & 4 deletions providers/github/resources/github_org.go
@@ -4,12 +4,12 @@
package resources

import (
"errors"
"slices"
"strconv"
"strings"
"time"

"github.com/cockroachdb/errors"
"github.com/google/go-github/v67/github"
"github.com/rs/zerolog/log"
"go.mondoo.com/cnquery/v11/internal/workerpool"
@@ -287,7 +287,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) {

for {
// exit as soon as we collect all repositories
reposLen := len(slices.Concat(workerPool.GetResults()...))
reposLen := len(slices.Concat(workerPool.GetValues()...))
if reposLen >= int(repoCount) {
break
}
@@ -303,7 +303,8 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) {
listOpts.Page++

// check if any request failed
if err := workerPool.GetErrors(); err != nil {
if errs := workerPool.GetErrors(); len(errs) != 0 {
err := errors.Join(errs...)
if strings.Contains(err.Error(), "404") {
return nil, nil
}
@@ -316,7 +317,7 @@ func (g *mqlGithubOrganization) repositories() ([]interface{}, error) {
}

res := []interface{}{}
for _, repos := range workerPool.GetResults() {
for _, repos := range workerPool.GetValues() {
for i := range repos {
repo := repos[i]

