Skip to content

Commit

Permalink
⚡ discover assets in parallel (#4973)
Browse files Browse the repository at this point in the history
This change use the new `workerpool` package to discover assets in parallel.

There were a few places that I detected race conditions that I have fixed, I
also added a race detector `make` command which is running as a CI job.

For testing:

Using the `github` provider for testing, scanning an **organization that has around 3k repositories**.

**Before (~15 Minutes)**
```
TRC logger.FuncDur> func=explorer.discoverAssets took=890303.455042
```
**After (~2 minutes)**
```
TRC logger.FuncDur> func=explorer.discoverAssets took=124293.542166
```

Race Detector:

You can now run `make race/go` to check for race conditions.
```
$ make race/go
go test -race go.mondoo.com/cnquery/v11/internal/workerpool
ok  	go.mondoo.com/cnquery/v11/internal/workerpool	16.417s
go test -race go.mondoo.com/cnquery/v11/explorer/scan
ok  	go.mondoo.com/cnquery/v11/explorer/scan	2.487s
```

Additional commit history:

* ⚡ workerpool package to submit parallel requests
* ⚡ fetch org repositories in parallel
* ⚙️  add a collector to the workerpool
* 🚨 fix race conditions
* ⚡ discover assets in parallel
* 🧪 decrease workerpool wait ticker to 10ms
* 🐛 make `DiscoveredAssets.AddError()` thread safe
* 🧵 add mutex when running `provider.connect()`
* ⚙️  reduce the workerpool Task function
* ⚙️  return `pool.Result` as a combined struct
* 🏎️ fix more data race conditions
* 🤖 Run race detector on CI
* ⚙️  split plugin connect func and assignation

---------

Signed-off-by: Salim Afiune Maya <[email protected]>
  • Loading branch information
afiune authored Dec 16, 2024
1 parent b9bcf90 commit fee2739
Show file tree
Hide file tree
Showing 13 changed files with 212 additions and 102 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main-benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ jobs:
uses: actions/cache/save@v4
with:
path: ./cache
key: ${{ runner.os }}-benchmark-${{ github.run_id }}
key: ${{ runner.os }}-benchmark-${{ github.run_id }}
18 changes: 18 additions & 0 deletions .github/workflows/pr-test-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,24 @@ jobs:
name: test-results-cli
path: report.xml

go-race:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Import environment variables from file
run: cat ".github/env" >> $GITHUB_ENV

- name: Install Go
uses: actions/setup-go@v5
with:
go-version: ">=${{ env.golang-version }}"
cache: false

- name: Run race detector on selected packages
run: make race/go

go-bench:
runs-on: ubuntu-latest
if: github.ref != 'refs/heads/main'
Expand Down
12 changes: 12 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,18 @@
"shell", "ssh", "[email protected]",
],
},
{
"name": "scan github org",
"type": "go",
"request": "launch",
"program": "${workspaceRoot}/apps/cnquery/cnquery.go",
"args": [
"scan",
"github",
"org", "hit-training",
"--log-level", "trace"
]
},
{
"name": "Configure Built-in Providers",
"type": "go",
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,10 @@ test: test/go test/lint
benchmark/go:
go test -bench=. -benchmem go.mondoo.com/cnquery/v11/explorer/scan/benchmark

race/go:
go test -race go.mondoo.com/cnquery/v11/internal/workerpool
go test -race go.mondoo.com/cnquery/v11/explorer/scan

test/generate: prep/tools/mockgen
go generate ./providers

Expand Down
50 changes: 35 additions & 15 deletions explorer/scan/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package scan
import (
"context"
"errors"
"sync"
"time"

"github.com/rs/zerolog/log"
"go.mondoo.com/cnquery/v11/cli/config"
"go.mondoo.com/cnquery/v11/cli/execruntime"
"go.mondoo.com/cnquery/v11/internal/workerpool"
"go.mondoo.com/cnquery/v11/llx"
"go.mondoo.com/cnquery/v11/logger"
"go.mondoo.com/cnquery/v11/providers"
Expand All @@ -20,6 +22,9 @@ import (
"go.mondoo.com/cnquery/v11/providers-sdk/v1/upstream"
)

// number of parallel goroutines discovering assets
const workers = 10

type AssetWithRuntime struct {
Asset *inventory.Asset
Runtime *providers.Runtime
Expand All @@ -34,28 +39,30 @@ type DiscoveredAssets struct {
platformIds map[string]struct{}
Assets []*AssetWithRuntime
Errors []*AssetWithError
assetsLock sync.Mutex
}

// Add adds an asset and its runtime to the discovered assets list. It returns true if the
// asset has been added, false if it is a duplicate
func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool {
isDuplicate := false
d.assetsLock.Lock()
defer d.assetsLock.Unlock()

for _, platformId := range asset.PlatformIds {
if _, ok := d.platformIds[platformId]; ok {
isDuplicate = true
break
// duplicate
return false
}
d.platformIds[platformId] = struct{}{}
}
if isDuplicate {
return false
}

d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime})
return true
}

func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) {
d.assetsLock.Lock()
defer d.assetsLock.Unlock()
d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err})
}

Expand Down Expand Up @@ -161,17 +168,30 @@ func discoverAssets(rootAssetWithRuntime *AssetWithRuntime, resolvedRootAsset *i
return
}

pool := workerpool.New[*AssetWithRuntime](workers)
pool.Start()
defer pool.Close()

// for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset
for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets {
// create runtime for root asset
assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording)
if err != nil {
log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset")
discoveredAssets.AddError(a, err)
continue
}
for _, asset := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets {
pool.Submit(func() (*AssetWithRuntime, error) {
assetWithRuntime, err := createRuntimeForAsset(asset, upstream, recording)
if err != nil {
log.Error().Err(err).Str("asset", asset.GetName()).Msg("unable to create runtime for asset")
discoveredAssets.AddError(asset, err)
}
return assetWithRuntime, nil
})
}

// Wait for the workers to finish processing
pool.Wait()

// Get all assets with runtimes from the pool
for _, result := range pool.GetResults() {
assetWithRuntime := result.Value

// If no asset was returned and no error, then we observed a duplicate asset with a
// If asset is nil, then we observed a duplicate asset with a
// runtime that already exists.
if assetWithRuntime == nil {
continue
Expand Down
38 changes: 21 additions & 17 deletions internal/workerpool/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@ import (
)

type collector[R any] struct {
resultsCh <-chan R
results []R
resultsCh <-chan Result[R]
results []Result[R]
read sync.Mutex

errorsCh <-chan error
errors []error

// The total number of requests read.
requestsRead int64
}

Expand All @@ -27,29 +25,35 @@ func (c *collector[R]) start() {
c.read.Lock()
c.results = append(c.results, result)
c.read.Unlock()

case err := <-c.errorsCh:
c.read.Lock()
c.errors = append(c.errors, err)
c.read.Unlock()
}

atomic.AddInt64(&c.requestsRead, 1)
}
}()
}
func (c *collector[R]) GetResults() []R {

func (c *collector[R]) RequestsRead() int64 {
return atomic.LoadInt64(&c.requestsRead)
}

func (c *collector[R]) GetResults() []Result[R] {
c.read.Lock()
defer c.read.Unlock()
return c.results
}

func (c *collector[R]) GetErrors() []error {
c.read.Lock()
defer c.read.Unlock()
return c.errors
func (c *collector[R]) GetValues() (slice []R) {
results := c.GetResults()
for i := range results {
slice = append(slice, results[i].Value)
}
return
}

func (c *collector[R]) RequestsRead() int64 {
return atomic.LoadInt64(&c.requestsRead)
func (c *collector[R]) GetErrors() (slice []error) {
results := c.GetResults()
for i := range results {
slice = append(slice, results[i].Error)
}
return
}
66 changes: 45 additions & 21 deletions internal/workerpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,40 @@ import (
"sync"
"sync/atomic"
"time"

"github.com/cockroachdb/errors"
)

// Represent the tasks that can be sent to the pool.
type Task[R any] func() (result R, err error)

// The result generated from a task.
type Result[R any] struct {
Value R
Error error
}

// Pool is a generic pool of workers.
type Pool[R any] struct {
queueCh chan Task[R]
resultsCh chan R
errorsCh chan error
// The queue where tasks are submitted.
queueCh chan Task[R]

// Where workers send the results after a task is executed,
// the collector then reads them and aggregate them.
resultsCh chan Result[R]

// The total number of requests sent.
requestsSent int64
once sync.Once

workers []*worker[R]
// Number of workers to spawn.
workerCount int

// The list of workers that are listening to the queue.
workers []*worker[R]

// A single collector to aggregate results.
collector[R]

// used to protect starting the pool multiple times
once sync.Once
}

// New initializes a new Pool with the provided number of workers. The pool is generic and can
Expand All @@ -37,14 +52,12 @@ type Pool[R any] struct {
// return 42, nil
// }
func New[R any](count int) *Pool[R] {
resultsCh := make(chan R)
errorsCh := make(chan error)
resultsCh := make(chan Result[R])
return &Pool[R]{
queueCh: make(chan Task[R]),
resultsCh: resultsCh,
errorsCh: errorsCh,
workerCount: count,
collector: collector[R]{resultsCh: resultsCh, errorsCh: errorsCh},
collector: collector[R]{resultsCh: resultsCh},
}
}

Expand All @@ -56,7 +69,7 @@ func New[R any](count int) *Pool[R] {
func (p *Pool[R]) Start() {
p.once.Do(func() {
for i := 0; i < p.workerCount; i++ {
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh, errorsCh: p.errorsCh}
w := worker[R]{id: i, queueCh: p.queueCh, resultsCh: p.resultsCh}
w.start()
p.workers = append(p.workers, &w)
}
Expand All @@ -67,22 +80,33 @@ func (p *Pool[R]) Start() {

// Submit sends a task to the workers
func (p *Pool[R]) Submit(t Task[R]) {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}

// GetErrors returns any error from a processed task
func (p *Pool[R]) GetErrors() error {
return errors.Join(p.collector.GetErrors()...)
if t != nil {
p.queueCh <- t
atomic.AddInt64(&p.requestsSent, 1)
}
}

// GetResults returns the tasks results.
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetResults() []R {
func (p *Pool[R]) GetResults() []Result[R] {
return p.collector.GetResults()
}

// GetValues returns only the values of the pool results
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GetValues() []R {
return p.collector.GetValues()
}

// GetErrors returns only the errors of the pool results
//
// It is recommended to call `Wait()` before reading the results.
func (p *Pool[R]) GettErrors() []error {
return p.collector.GetErrors()
}

// Close waits for workers and collector to process all the requests, and then closes
// the task queue channel. After closing the pool, calling `Submit()` will panic.
func (p *Pool[R]) Close() {
Expand All @@ -92,7 +116,7 @@ func (p *Pool[R]) Close() {

// Wait waits until all tasks have been processed.
func (p *Pool[R]) Wait() {
ticker := time.NewTicker(100 * time.Millisecond)
ticker := time.NewTicker(10 * time.Millisecond)
for {
if !p.Processing() {
return
Expand Down
Loading

0 comments on commit fee2739

Please sign in to comment.