From 9976616773313bb56d052996be3f0d5fcee99d4a Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:59:10 +0200 Subject: [PATCH] context propagation: StreamingAcquisition() (#3274) * context propagation: StreamingAcquisition() * lint * ship with codecov.yml --- .github/generate-codecov-yml.sh | 3 + .github/workflows/go-tests.yml | 4 -- cmd/crowdsec/crowdsec.go | 2 +- pkg/acquisition/acquisition.go | 56 ++++++++++++++----- pkg/acquisition/acquisition_test.go | 52 +++++++++-------- pkg/acquisition/modules/appsec/appsec.go | 12 ++-- .../modules/cloudwatch/cloudwatch.go | 46 ++++++++++----- .../modules/cloudwatch/cloudwatch_test.go | 13 +++-- pkg/acquisition/modules/docker/docker.go | 8 +-- pkg/acquisition/modules/docker/docker_test.go | 5 +- pkg/acquisition/modules/file/file.go | 3 +- pkg/acquisition/modules/file/file_test.go | 4 +- .../modules/journalctl/journalctl.go | 8 +-- .../modules/journalctl/journalctl_test.go | 4 +- pkg/acquisition/modules/kafka/kafka.go | 6 +- pkg/acquisition/modules/kafka/kafka_test.go | 14 +++-- pkg/acquisition/modules/kinesis/kinesis.go | 34 +++++------ .../modules/kinesis/kinesis_test.go | 26 +++++---- .../modules/kubernetesaudit/k8s_audit.go | 3 +- .../modules/kubernetesaudit/k8s_audit_test.go | 9 ++- pkg/acquisition/modules/loki/loki.go | 6 +- pkg/acquisition/modules/loki/loki_test.go | 9 ++- pkg/acquisition/modules/s3/s3.go | 21 ++++--- pkg/acquisition/modules/s3/s3_test.go | 6 +- pkg/acquisition/modules/syslog/syslog.go | 9 +-- pkg/acquisition/modules/syslog/syslog_test.go | 16 ++++-- .../modules/wineventlog/wineventlog.go | 3 +- .../modules/wineventlog/wineventlog_test.go | 4 +- .../wineventlog/wineventlog_windows.go | 3 +- 29 files changed, 235 insertions(+), 154 deletions(-) diff --git a/.github/generate-codecov-yml.sh b/.github/generate-codecov-yml.sh index cc2d652e339..ddb60d0ce80 100755 --- a/.github/generate-codecov-yml.sh +++ b/.github/generate-codecov-yml.sh @@ -7,6 +7,9 @@ cat <> .github/codecov.yml - - name: Upload unit coverage to Codecov uses: codecov/codecov-action@v4 with: diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 460e8ab4328..c44d71d2093 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -169,7 +169,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H log.Info("Starting processing data") - if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil { + if err := acquisition.StartAcquisition(context.TODO(), dataSources, inputLineChan, &acquisTomb); err != nil { return fmt.Errorf("starting acquisition error: %w", err) } diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 4a5226a2981..4519ea7392b 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -1,6 +1,7 @@ package acquisition import ( + "context" "errors" "fmt" "io" @@ -39,17 +40,17 @@ func (e *DataSourceUnavailableError) Unwrap() error { // The interface each datasource must implement type DataSource interface { - GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module - GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) - UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime - Configure([]byte, *log.Entry, int) error // Complete the YAML datasource configuration and 
perform runtime checks. - ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource - GetMode() string // Get the mode (TAIL, CAT or SERVER) - GetName() string // Get the name of the module - OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) - StreamingAcquisition(chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) - CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) - GetUuid() string // Get the unique identifier of the datasource + GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module + GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) + UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime + Configure([]byte, *log.Entry, int) error // Complete the YAML datasource configuration and perform runtime checks. + ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource + GetMode() string // Get the mode (TAIL, CAT or SERVER) + GetName() string // Get the name of the module + OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) + StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) + CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) + GetUuid() string // Get the unique identifier of the datasource Dump() interface{} } @@ -242,8 +243,10 @@ func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig for { var sub configuration.DataSourceCommonCfg - err = dec.Decode(&sub) + idx += 1 + + err = dec.Decode(&sub) if err != nil { if !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to yaml decode %s: %w", acquisFile, err) @@ -283,6 +286,7 @@ func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig uniqueId := uuid.NewString() sub.UniqueId = uniqueId + src, err := DataSourceConfigure(sub, metrics_level) if err != nil { var dserr *DataSourceUnavailableError @@ -290,29 +294,36 @@ func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig log.Error(err) continue } + return nil, fmt.Errorf("while configuring datasource of type %s from %s (position: %d): %w", sub.Source, acquisFile, idx, err) } + if sub.TransformExpr != "" { vm, err := expr.Compile(sub.TransformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) 
if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s' for datasource %s in %s (position: %d): %w", sub.TransformExpr, sub.Source, acquisFile, idx, err) } + transformRuntimes[uniqueId] = vm } + sources = append(sources, *src) } } + return sources, nil } func GetMetrics(sources []DataSource, aggregated bool) error { var metrics []prometheus.Collector + for i := range sources { if aggregated { metrics = sources[i].GetMetrics() } else { metrics = sources[i].GetAggregMetrics() } + for _, metric := range metrics { if err := prometheus.Register(metric); err != nil { if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { @@ -322,12 +333,14 @@ func GetMetrics(sources []DataSource, aggregated bool) error { } } } + return nil } func transform(transformChan chan types.Event, output chan types.Event, AcquisTomb *tomb.Tomb, transformRuntime *vm.Program, logger *log.Entry) { defer trace.CatchPanic("crowdsec/acquis") logger.Infof("transformer started") + for { select { case <-AcquisTomb.Dying(): @@ -335,15 +348,18 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo return case evt := <-transformChan: logger.Tracef("Received event %s", evt.Line.Raw) + out, err := expr.Run(transformRuntime, map[string]interface{}{"evt": &evt}) if err != nil { logger.Errorf("while running transform expression: %s, sending event as-is", err) output <- evt } + if out == nil { logger.Errorf("transform expression returned nil, sending event as-is") output <- evt } + switch v := out.(type) { case string: logger.Tracef("transform expression returned %s", v) @@ -351,18 +367,22 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo output <- evt case []interface{}: logger.Tracef("transform expression returned %v", v) //nolint:asasalint // We actually want to log the slice content + for _, line := range v { l, ok := line.(string) if !ok { logger.Errorf("transform expression returned []interface{}, but cannot assert an element to string") output <- evt + continue } + evt.Line.Raw = l output <- evt } case []string: logger.Tracef("transform expression returned %v", v) + for _, line := range v { evt.Line.Raw = line output <- evt @@ -375,7 +395,7 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo } } -func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +func StartAcquisition(ctx context.Context, sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { // Don't wait if we have no sources, as it will hang forever if len(sources) == 0 { return nil @@ -387,32 +407,40 @@ func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb AcquisTomb.Go(func() error { defer trace.CatchPanic("crowdsec/acquis") + var err error outChan := output + log.Debugf("datasource %s UUID: %s", subsrc.GetName(), subsrc.GetUuid()) + if transformRuntime, ok := transformRuntimes[subsrc.GetUuid()]; ok { log.Infof("transform expression found for datasource %s", subsrc.GetName()) + transformChan := make(chan types.Event) outChan = transformChan transformLogger := log.WithFields(log.Fields{ "component": "transform", "datasource": subsrc.GetName(), }) + AcquisTomb.Go(func() error { transform(outChan, output, AcquisTomb, transformRuntime, transformLogger) return nil }) } + if subsrc.GetMode() == configuration.TAIL_MODE { - err = subsrc.StreamingAcquisition(outChan, AcquisTomb) + err = subsrc.StreamingAcquisition(ctx, outChan, AcquisTomb) } else { err = 
subsrc.OneShotAcquisition(outChan, AcquisTomb) } + if err != nil { // if one of the acqusition returns an error, we kill the others to properly shutdown AcquisTomb.Kill(err) } + return nil }) } diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index e39199f9cdb..e82b3df54c2 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -1,6 +1,7 @@ package acquisition import ( + "context" "errors" "fmt" "strings" @@ -56,14 +57,16 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry, metricsLevel int) return nil } -func (f *MockSource) GetMode() string { return f.Mode } -func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) CanRun() error { return nil } -func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSource) Dump() interface{} { return f } -func (f *MockSource) GetName() string { return "mock" } +func (f *MockSource) GetMode() string { return f.Mode } +func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSource) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSource) CanRun() error { return nil } +func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSource) Dump() interface{} { return f } +func (f *MockSource) GetName() string { return "mock" } func (f *MockSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { return errors.New("not supported") } @@ -327,7 +330,7 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro return nil } -func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { +func (f *MockCat) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { return errors.New("can't run in tail") } func (f *MockCat) CanRun() error { return nil } @@ -366,7 +369,7 @@ func (f *MockTail) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) err return errors.New("can't run in cat mode") } -func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *MockTail) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" @@ -389,6 +392,7 @@ func (f *MockTail) GetUuid() string { return "" } // func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { func TestStartAcquisitionCat(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockCat{}, } @@ -396,7 +400,7 @@ func TestStartAcquisitionCat(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -416,6 +420,7 @@ READLOOP: } func TestStartAcquisitionTail(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTail{}, } @@ -423,7 +428,7 @@ func TestStartAcquisitionTail(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := 
StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -450,7 +455,7 @@ type MockTailError struct { MockTail } -func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *MockTailError) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" @@ -463,6 +468,7 @@ func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) } func TestStartAcquisitionTailError(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTailError{}, } @@ -470,7 +476,7 @@ func TestStartAcquisitionTailError(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { t.Errorf("expected error, got '%s'", err) } }() @@ -501,14 +507,16 @@ func (f *MockSourceByDSN) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { return nil } -func (f *MockSourceByDSN) GetMode() string { return f.Mode } -func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) CanRun() error { return nil } -func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) Dump() interface{} { return f } -func (f *MockSourceByDSN) GetName() string { return "mockdsn" } +func (f *MockSourceByDSN) GetMode() string { return f.Mode } +func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSourceByDSN) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSourceByDSN) CanRun() error { return nil } +func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) Dump() interface{} { return f } +func (f *MockSourceByDSN) GetName() string { return "mockdsn" } func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { dsn = strings.TrimPrefix(dsn, "mockdsn://") if dsn != "test_expect" { diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 8a93326c7e3..5161b631c33 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -59,7 +59,7 @@ type AppsecSource struct { AppsecConfigs map[string]appsec.AppsecConfig lapiURL string AuthCache AuthCache - AppsecRunners []AppsecRunner //one for each go-routine + AppsecRunners []AppsecRunner // one for each go-routine } // Struct to handle cache of authentication @@ -172,7 +172,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe w.InChan = make(chan appsec.ParsedRequest) appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")} - //let's load the associated appsec_config: + // let's load the associated appsec_config: if w.config.AppsecConfigPath != "" { err := 
appsecCfg.LoadByPath(w.config.AppsecConfigPath) if err != nil { @@ -201,7 +201,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe for nbRoutine := range w.config.Routines { appsecRunnerUUID := uuid.New().String() - //we copy AppsecRutime for each runner + // we copy AppsecRutime for each runner wrt := *w.AppsecRuntime wrt.Logger = w.logger.Dup().WithField("runner_uuid", appsecRunnerUUID) runner := AppsecRunner{ @@ -220,7 +220,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe w.logger.Infof("Created %d appsec runners", len(w.AppsecRunners)) - //We don´t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec + // We don´t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec w.mux.HandleFunc(w.config.Path, w.appsecHandler) return nil } @@ -241,7 +241,7 @@ func (w *AppsecSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er return errors.New("AppSec datasource does not support command line acquisition") } -func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { w.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/appsec/live") @@ -292,7 +292,7 @@ func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) }) <-t.Dying() w.logger.Info("Shutting down Appsec server") - //xx let's clean up the appsec runners :) + // xx let's clean up the appsec runners :) appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails) w.server.Shutdown(context.TODO()) return nil diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go index d6f33b68050..e4b6c95d77f 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go @@ -159,6 +159,7 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, Metr if err != nil { return err } + cw.metricsLevel = MetricsLevel cw.logger = logger.WithField("group", cw.Config.GroupName) @@ -175,16 +176,18 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, Metr if *cw.Config.MaxStreamAge > *cw.Config.StreamReadTimeout { cw.logger.Warningf("max_stream_age > stream_read_timeout, stream might keep being opened/closed") } + cw.logger.Tracef("aws_config_dir set to %s", *cw.Config.AwsConfigDir) if *cw.Config.AwsConfigDir != "" { _, err := os.Stat(*cw.Config.AwsConfigDir) if err != nil { cw.logger.Errorf("can't read aws_config_dir '%s' got err %s", *cw.Config.AwsConfigDir, err) - return fmt.Errorf("can't read aws_config_dir %s got err %s ", *cw.Config.AwsConfigDir, err) + return fmt.Errorf("can't read aws_config_dir %s got err %w ", *cw.Config.AwsConfigDir, err) } + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - //as aws sdk relies on $HOME, let's allow the user to override it :) + // as aws sdk relies on $HOME, let's allow the user to override it :) os.Setenv("AWS_CONFIG_FILE", fmt.Sprintf("%s/config", *cw.Config.AwsConfigDir)) os.Setenv("AWS_SHARED_CREDENTIALS_FILE", fmt.Sprintf("%s/credentials", *cw.Config.AwsConfigDir)) } else { @@ -192,25 +195,30 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, Metr cw.logger.Errorf("aws_region is not specified, specify it or aws_config_dir") 
return errors.New("aws_region is not specified, specify it or aws_config_dir") } + os.Setenv("AWS_REGION", *cw.Config.AwsRegion) } if err := cw.newClient(); err != nil { return err } + cw.streamIndexes = make(map[string]string) targetStream := "*" + if cw.Config.StreamRegexp != nil { if _, err := regexp.Compile(*cw.Config.StreamRegexp); err != nil { return fmt.Errorf("while compiling regexp '%s': %w", *cw.Config.StreamRegexp, err) } + targetStream = *cw.Config.StreamRegexp } else if cw.Config.StreamName != nil { targetStream = *cw.Config.StreamName } cw.logger.Infof("Adding cloudwatch group '%s' (stream:%s) to datasources", cw.Config.GroupName, targetStream) + return nil } @@ -231,24 +239,29 @@ func (cw *CloudwatchSource) newClient() error { if sess == nil { return errors.New("failed to create aws session") } + if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" { cw.logger.Debugf("[testing] overloading endpoint with %s", v) cw.cwClient = cloudwatchlogs.New(sess, aws.NewConfig().WithEndpoint(v)) } else { cw.cwClient = cloudwatchlogs.New(sess) } + if cw.cwClient == nil { return errors.New("failed to create cloudwatch client") } + return nil } -func (cw *CloudwatchSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (cw *CloudwatchSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { cw.t = t monitChan := make(chan LogStreamTailConfig) + t.Go(func() error { return cw.LogStreamManager(monitChan, out) }) + return cw.WatchLogGroupForStreams(monitChan) } @@ -279,6 +292,7 @@ func (cw *CloudwatchSource) Dump() interface{} { func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig) error { cw.logger.Debugf("Starting to watch group (interval:%s)", cw.Config.PollNewStreamInterval) ticker := time.NewTicker(*cw.Config.PollNewStreamInterval) + var startFrom *string for { @@ -289,11 +303,12 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig case <-ticker.C: hasMoreStreams := true startFrom = nil + for hasMoreStreams { cw.logger.Tracef("doing the call to DescribeLogStreamsPagesWithContext") ctx := context.Background() - //there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime + // there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime err := cw.cwClient.DescribeLogStreamsPagesWithContext( ctx, &cloudwatchlogs.DescribeLogStreamsInput{ @@ -305,13 +320,14 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig }, func(page *cloudwatchlogs.DescribeLogStreamsOutput, lastPage bool) bool { cw.logger.Tracef("in helper of DescribeLogStreamsPagesWithContext") + for _, event := range page.LogStreams { startFrom = page.NextToken - //we check if the stream has been written to recently enough to be monitored + // we check if the stream has been written to recently enough to be monitored if event.LastIngestionTime != nil { - //aws uses millisecond since the epoch + // aws uses millisecond since the epoch oldest := time.Now().UTC().Add(-*cw.Config.MaxStreamAge) - //TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. + // TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. 
LastIngestionTime := time.Unix(0, *event.LastIngestionTime*int64(time.Millisecond)) if LastIngestionTime.Before(oldest) { cw.logger.Tracef("stop iteration, %s reached oldest age, stop (%s < %s)", *event.LogStreamName, LastIngestionTime, time.Now().UTC().Add(-*cw.Config.MaxStreamAge)) @@ -319,7 +335,7 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig return false } cw.logger.Tracef("stream %s is elligible for monitoring", *event.LogStreamName) - //the stream has been updated recently, check if we should monitor it + // the stream has been updated recently, check if we should monitor it var expectMode int if !cw.Config.UseTimeMachine { expectMode = types.LIVE @@ -383,7 +399,7 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha for idx, stream := range cw.monitoredStreams { if newStream.GroupName == stream.GroupName && newStream.StreamName == stream.StreamName { - //stream exists, but is dead, remove it from list + // stream exists, but is dead, remove it from list if !stream.t.Alive() { cw.logger.Debugf("stream %s already exists, but is dead", newStream.StreamName) cw.monitoredStreams = append(cw.monitoredStreams[:idx], cw.monitoredStreams[idx+1:]...) @@ -397,7 +413,7 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha } } - //let's start watching this stream + // let's start watching this stream if shouldCreate { if cw.metricsLevel != configuration.METRICS_NONE { openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Inc() @@ -445,7 +461,7 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan var startFrom *string lastReadMessage := time.Now().UTC() ticker := time.NewTicker(cfg.PollStreamInterval) - //resume at existing index if we already had + // resume at existing index if we already had streamIndexMutex.Lock() v := cw.streamIndexes[cfg.GroupName+"+"+cfg.StreamName] streamIndexMutex.Unlock() @@ -566,7 +582,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, if len(v) != 1 { return errors.New("expected zero or one argument for 'start_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, startDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.StartTime = &startDate @@ -574,7 +590,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, if len(v) != 1 { return errors.New("expected zero or one argument for 'end_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, endDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.EndTime = &endDate @@ -582,7 +598,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, if len(v) != 1 { return errors.New("expected zero or one argument for 'backlog'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported duration, err := time.ParseDuration(v[0]) if err != nil { return fmt.Errorf("unable to parse '%s' as duration: %w", v[0], err) @@ -618,7 +634,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, } func (cw *CloudwatchSource) OneShotAcquisition(out chan 
types.Event, t *tomb.Tomb) error { - //StreamName string, Start time.Time, End time.Time + // StreamName string, Start time.Time, End time.Time config := LogStreamTailConfig{ GroupName: cw.Config.GroupName, StreamName: *cw.Config.StreamName, diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go index bab7593f26f..d62c3f6e3dd 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go @@ -1,6 +1,7 @@ package cloudwatchacquisition import ( + "context" "errors" "fmt" "net" @@ -34,6 +35,7 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) { input := &cloudwatchlogs.DescribeLogGroupsInput{} result, err := cw.cwClient.DescribeLogGroups(input) require.NoError(t, err) + for _, group := range result.LogGroups { _, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{ LogGroupName: group.LogGroupName, @@ -62,18 +64,22 @@ func TestMain(m *testing.M) { if runtime.GOOS == "windows" { os.Exit(0) } + if err := checkForLocalStackAvailability(); err != nil { log.Fatalf("local stack error : %s", err) } + def_PollNewStreamInterval = 1 * time.Second def_PollStreamInterval = 1 * time.Second def_StreamReadTimeout = 10 * time.Second def_MaxStreamAge = 5 * time.Second def_PollDeadStreamInterval = 5 * time.Second + os.Exit(m.Run()) } func TestWatchLogGroupForStreams(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -447,7 +453,7 @@ stream_name: test_stream`), dbgLogger.Infof("running StreamingAcquisition") actmb := tomb.Tomb{} actmb.Go(func() error { - err := cw.StreamingAcquisition(out, &actmb) + err := cw.StreamingAcquisition(ctx, out, &actmb) dbgLogger.Infof("acquis done") cstest.RequireErrorContains(t, err, tc.expectedStartErr) return nil @@ -503,7 +509,6 @@ stream_name: test_stream`), if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) @@ -513,6 +518,7 @@ stream_name: test_stream`), } func TestConfiguration(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -571,7 +577,7 @@ stream_name: test_stream`), switch cw.GetMode() { case "tail": - err = cw.StreamingAcquisition(out, &tmb) + err = cw.StreamingAcquisition(ctx, out, &tmb) case "cat": err = cw.OneShotAcquisition(out, &tmb) } @@ -798,7 +804,6 @@ func TestOneShotAcquisition(t *testing.T) { if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go index 44fee0a99a2..874b1556fd5 100644 --- a/pkg/acquisition/modules/docker/docker.go +++ b/pkg/acquisition/modules/docker/docker.go @@ -518,7 +518,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha } } -func (d *DockerSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (d *DockerSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { d.t = t monitChan := make(chan *ContainerConfig) deleteChan := make(chan *ContainerConfig) @@ -589,11 +589,11 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types outChan <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) case <-readerTomb.Dying(): - //This case is to handle temporarily losing the connection to the docker socket - 
//The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) + // This case is to handle temporarily losing the connection to the docker socket + // The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) d.logger.Debugf("readerTomb dying for container %s, removing it from runningContainerState", container.Name) deleteChan <- container - //Also reset the Since to avoid re-reading logs + // Also reset the Since to avoid re-reading logs d.Config.Since = time.Now().UTC().Format(time.RFC3339) d.containerLogsOptions.Since = d.Config.Since return nil diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index e332569fb3a..e394c9cbe79 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -120,6 +120,7 @@ type mockDockerCli struct { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") @@ -185,7 +186,7 @@ container_name_regexp: readerTomb := &tomb.Tomb{} streamTomb := tomb.Tomb{} streamTomb.Go(func() error { - return dockerSource.StreamingAcquisition(out, &dockerTomb) + return dockerSource.StreamingAcquisition(ctx, out, &dockerTomb) }) readerTomb.Go(func() error { time.Sleep(1 * time.Second) @@ -245,7 +246,7 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o for _, line := range data { startLineByte := make([]byte, 8) - binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream + binary.LittleEndian.PutUint32(startLineByte, 1) // stdout stream binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) ret += fmt.Sprintf("%s%s", startLineByte, line) } diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go index 85b4c1b5b32..2d2df3ff4d4 100644 --- a/pkg/acquisition/modules/file/file.go +++ b/pkg/acquisition/modules/file/file.go @@ -3,6 +3,7 @@ package fileacquisition import ( "bufio" "compress/gzip" + "context" "errors" "fmt" "io" @@ -320,7 +321,7 @@ func (f *FileSource) CanRun() error { return nil } -func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *FileSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("Starting live acquisition") t.Go(func() error { return f.monitorNewFiles(out, t) diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index 5d38552b3c5..3db0042ba2f 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,6 +1,7 @@ package fileacquisition_test import ( + "context" "fmt" "os" "runtime" @@ -243,6 +244,7 @@ filename: test_files/test_delete.log`, } func TestLiveAcquisition(t *testing.T) { + ctx := context.Background() permDeniedFile := "/etc/shadow" permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied" testPattern := "test_files/*.log" @@ -394,7 +396,7 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(out, &tomb) + err = f.StreamingAcquisition(ctx, out, &tomb) cstest.RequireErrorContains(t, err, tc.expectedErr) if tc.expectedLines != 0 { diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go index 1336fac4578..b9cda54a472 100644 --- a/pkg/acquisition/modules/journalctl/journalctl.go 
+++ b/pkg/acquisition/modules/journalctl/journalctl.go @@ -113,7 +113,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err return readLine(stdoutscanner, stdoutChan, errChan) }) t.Go(func() error { - //looks like journalctl closes stderr quite early, so ignore its status (but not its output) + // looks like journalctl closes stderr quite early, so ignore its status (but not its output) return readLine(stderrScanner, stderrChan, nil) }) @@ -122,7 +122,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err case <-t.Dying(): logger.Infof("journalctl datasource %s stopping", j.src) cancel() - cmd.Wait() //avoid zombie process + cmd.Wait() // avoid zombie process return nil case stdoutLine := <-stdoutChan: l := types.Line{} @@ -217,7 +217,7 @@ func (j *JournalCtlSource) ConfigureByDSN(dsn string, labels map[string]string, j.config.Labels = labels j.config.UniqueId = uuid - //format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 + // format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 if !strings.HasPrefix(dsn, "journalctl://") { return fmt.Errorf("invalid DSN %s for journalctl source, must start with journalctl://", dsn) } @@ -269,7 +269,7 @@ func (j *JournalCtlSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb return err } -func (j *JournalCtlSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (j *JournalCtlSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/journalctl/streaming") return j.runJournalCtl(out, t) diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index 53e2d0802ad..c416bb5d23e 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -1,6 +1,7 @@ package journalctlacquisition import ( + "context" "os" "os/exec" "path/filepath" @@ -187,6 +188,7 @@ journalctl_filter: } func TestStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -250,7 +252,7 @@ journalctl_filter: }() } - err = j.StreamingAcquisition(out, &tomb) + err = j.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if err != nil { diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go index a0d7fc39bcc..9fd5fc2a035 100644 --- a/pkg/acquisition/modules/kafka/kafka.go +++ b/pkg/acquisition/modules/kafka/kafka.go @@ -23,9 +23,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - dataSourceName = "kafka" -) +var dataSourceName = "kafka" var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -204,7 +202,7 @@ func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { } } -func (k *KafkaSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KafkaSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { k.logger.Infof("start reader on brokers '%+v' with topic '%s'", k.Config.Brokers, k.Config.Topic) t.Go(func() error { diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 7b467142cc9..d796166a6ca 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -80,9 +80,9 @@ group_id: crowdsec`, } } 
-func writeToKafka(w *kafka.Writer, logs []string) { +func writeToKafka(ctx context.Context, w *kafka.Writer, logs []string) { for idx, log := range logs { - err := w.WriteMessages(context.Background(), kafka.Message{ + err := w.WriteMessages(ctx, kafka.Message{ Key: []byte(strconv.Itoa(idx)), // create an arbitrary message payload for the value Value: []byte(log), @@ -128,6 +128,7 @@ func createTopic(topic string, broker string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -176,12 +177,12 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w, ts.logs) + go writeToKafka(ctx, w, ts.logs) READLOOP: for { select { @@ -199,6 +200,7 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) } func TestStreamingAcquisitionWithSSL(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -252,12 +254,12 @@ tls: tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w2, ts.logs) + go writeToKafka(ctx, w2, ts.logs) READLOOP: for { select { diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index 0e6c1980fa9..ca3a847dbfb 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "errors" "fmt" @@ -29,7 +30,7 @@ type KinesisConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` StreamName string `yaml:"stream_name"` StreamARN string `yaml:"stream_arn"` - UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` //Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords + UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` // Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords AwsProfile *string `yaml:"aws_profile"` AwsRegion string `yaml:"aws_region"` AwsEndpoint string `yaml:"aws_endpoint"` @@ -114,8 +115,8 @@ func (k *KinesisSource) newClient() error { func (k *KinesisSource) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} - } + func (k *KinesisSource) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} } @@ -188,7 +189,6 @@ func (k *KinesisSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) e func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubscriptionLogEvent, error) { b := bytes.NewBuffer(record) r, err := gzip.NewReader(b) - if err != nil { k.logger.Error(err) return nil, err @@ -299,8 +299,8 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan var data []CloudwatchSubscriptionLogEvent var err error if k.Config.FromSubscription { - //The AWS docs says that the data is base64 encoded - //but apparently GetRecords decodes it for us ? + // The AWS docs says that the data is base64 encoded + // but apparently GetRecords decodes it for us ? 
data, err = k.decodeFromSubscription(record.Data) if err != nil { logger.Errorf("Cannot decode data: %s", err) @@ -335,9 +335,9 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEventStreamReader, out chan types.Event, shardId string, streamName string) error { logger := k.logger.WithField("shard_id", shardId) - //ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately - //and we won't be able to start a new one if this is the first one started by the tomb - //TODO: look into parent shards to see if a shard is closed before starting to read it ? + // ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately + // and we won't be able to start a new one if this is the first one started by the tomb + // TODO: look into parent shards to see if a shard is closed before starting to read it ? time.Sleep(time.Second) for { select { @@ -420,7 +420,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { case <-t.Dying(): k.logger.Infof("Kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves err = k.DeregisterConsumer() if err != nil { return fmt.Errorf("cannot deregister consumer: %w", err) @@ -431,7 +431,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { if k.shardReaderTomb.Err() != nil { return k.shardReaderTomb.Err() } - //All goroutines have exited without error, so a resharding event, start again + // All goroutines have exited without error, so a resharding event, start again k.logger.Debugf("All reader goroutines have exited, resharding event or periodic resubscribe") continue } @@ -441,15 +441,17 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) error { logger := k.logger.WithField("shard", shardId) logger.Debugf("Starting to read shard") - sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ShardId: aws.String(shardId), + sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ + ShardId: aws.String(shardId), StreamName: &k.Config.StreamName, - ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest)}) + ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest), + }) if err != nil { logger.Errorf("Cannot get shard iterator: %s", err) return fmt.Errorf("cannot get shard iterator: %w", err) } it := sharIt.ShardIterator - //AWS recommends to wait for a second between calls to GetRecords for a given shard + // AWS recommends to wait for a second between calls to GetRecords for a given shard ticker := time.NewTicker(time.Second) for { select { @@ -460,7 +462,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro switch err.(type) { case *kinesis.ProvisionedThroughputExceededException: logger.Warn("Provisioned throughput exceeded") - //TODO: implement exponential backoff + // TODO: implement exponential backoff continue case *kinesis.ExpiredIteratorException: logger.Warn("Expired iterator") @@ -506,7 +508,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error case <-t.Dying(): k.logger.Info("kinesis source is dying") 
k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves return nil case <-k.shardReaderTomb.Dying(): reason := k.shardReaderTomb.Err() @@ -520,7 +522,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error } } -func (k *KinesisSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming") if k.Config.UseEnhancedFanOut { diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go index 46e404aa49b..027cbde9240 100644 --- a/pkg/acquisition/modules/kinesis/kinesis_test.go +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "fmt" "net" @@ -60,8 +61,8 @@ func GenSubObject(i int) []byte { gz := gzip.NewWriter(&b) gz.Write(body) gz.Close() - //AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point - //localstack does not do it, so let's just write a raw gzipped stream + // AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point + // localstack does not do it, so let's just write a raw gzipped stream return b.Bytes() } @@ -99,10 +100,10 @@ func TestMain(m *testing.M) { os.Setenv("AWS_ACCESS_KEY_ID", "foobar") os.Setenv("AWS_SECRET_ACCESS_KEY", "foobar") - //delete_streams() - //create_streams() + // delete_streams() + // create_streams() code := m.Run() - //delete_streams() + // delete_streams() os.Exit(code) } @@ -149,6 +150,7 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, } func TestReadFromStream(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -176,11 +178,11 @@ stream_name: stream-1-shard`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) for i := range test.count { @@ -193,6 +195,7 @@ stream_name: stream-1-shard`, } func TestReadFromMultipleShards(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -220,11 +223,11 @@ stream_name: stream-2-shards`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) c := 0 @@ -239,6 +242,7 @@ stream_name: stream-2-shards`, } func TestFromSubscription(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -267,11 +271,11 @@ from_subscription: true`, } tomb := &tomb.Tomb{} out := make(chan 
types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, true) for i := range test.count { diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go index 8ba5b2d06e0..f979b044dcc 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go @@ -135,7 +135,7 @@ func (ka *KubernetesAuditSource) OneShotAcquisition(out chan types.Event, t *tom return errors.New("k8s-audit datasource does not support one-shot acquisition") } -func (ka *KubernetesAuditSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { ka.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live") @@ -164,7 +164,6 @@ func (ka *KubernetesAuditSource) Dump() interface{} { } func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.Request) { - if ka.metricsLevel != configuration.METRICS_NONE { requestCount.WithLabelValues(ka.addr).Inc() } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index 020bd4c91a0..a086a756e4a 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,6 +1,7 @@ package kubernetesauditacquisition import ( + "context" "net/http/httptest" "strings" "testing" @@ -52,6 +53,7 @@ listen_addr: 0.0.0.0`, } func TestInvalidConfig(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -83,7 +85,7 @@ webhook_path: /k8s-audit`, err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) @@ -98,6 +100,7 @@ webhook_path: /k8s-audit`, } func TestHandler(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -257,14 +260,14 @@ webhook_path: /k8s-audit`, req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) f.webhookHandler(w, req) res := w.Result() assert.Equal(t, test.expectedStatusCode, res.StatusCode) - //time.Sleep(1 * time.Second) + // time.Sleep(1 * time.Second) require.NoError(t, err) tb.Kill(nil) diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go index 15c454723ee..f867feeb84b 100644 --- a/pkg/acquisition/modules/loki/loki.go +++ b/pkg/acquisition/modules/loki/loki.go @@ -319,9 +319,9 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri } } -func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (l *LokiSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(context.Background(), l.Config.WaitForReady) + readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady) 
 	defer cancel()
 	err := l.Client.Ready(readyCtx)
 	if err != nil {
@@ -329,7 +329,7 @@ func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) er
 	}
 	ll := l.logger.WithField("websocket_url", l.lokiWebsocket)
 	t.Go(func() error {
-		ctx, cancel := context.WithCancel(context.Background())
+		ctx, cancel := context.WithCancel(ctx)
 		defer cancel()
 		respChan := l.Client.QueryRange(ctx, true)
 		if err != nil {
diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go
index 2fd2b61e995..627200217f5 100644
--- a/pkg/acquisition/modules/loki/loki_test.go
+++ b/pkg/acquisition/modules/loki/loki_test.go
@@ -439,7 +439,7 @@ query: >
 			t.Fatalf("Unexpected error : %s", err)
 		}
 
-		err = lokiSource.StreamingAcquisition(out, &lokiTomb)
+		err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb)
 		cstest.AssertErrorContains(t, err, ts.streamErr)
 
 		if ts.streamErr != "" {
@@ -449,7 +449,7 @@ query: >
 
 		time.Sleep(time.Second * 2) // We need to give time to start reading from the WS
 		readTomb := tomb.Tomb{}
-		readCtx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+		readCtx, cancel := context.WithTimeout(ctx, time.Second*10)
 		count := 0
 
 		readTomb.Go(func() error {
@@ -492,6 +492,7 @@ query: >
 }
 
 func TestStopStreaming(t *testing.T) {
+	ctx := context.Background()
 	if runtime.GOOS == "windows" {
 		t.Skip("Skipping test on windows")
 	}
@@ -519,15 +520,13 @@ query: >
 
 	lokiTomb := &tomb.Tomb{}
 
-	err = lokiSource.StreamingAcquisition(out, lokiTomb)
+	err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb)
 	if err != nil {
 		t.Fatalf("Unexpected error : %s", err)
 	}
 
 	time.Sleep(time.Second * 2)
 
-	ctx := context.Background()
-
 	err = feedLoki(ctx, subLogger, 1, title)
 	if err != nil {
 		t.Fatalf("Unexpected error : %s", err)
diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go
index a9835ab4974..ed1964edebf 100644
--- a/pkg/acquisition/modules/s3/s3.go
+++ b/pkg/acquisition/modules/s3/s3.go
@@ -38,7 +38,7 @@ type S3Configuration struct {
 	AwsEndpoint     string `yaml:"aws_endpoint"`
 	BucketName      string `yaml:"bucket_name"`
 	Prefix          string `yaml:"prefix"`
-	Key             string `yaml:"-"` //Only for DSN acquisition
+	Key             string `yaml:"-"` // Only for DSN acquisition
 	PollingMethod   string `yaml:"polling_method"`
 	PollingInterval int    `yaml:"polling_interval"`
 	SQSName         string `yaml:"sqs_name"`
@@ -338,7 +338,7 @@ func (s *S3Source) sqsPoll() error {
 		out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{
 			QueueUrl:            aws.String(s.Config.SQSName),
 			MaxNumberOfMessages: aws.Int64(10),
-			WaitTimeSeconds:     aws.Int64(20), //Probably no need to make it configurable ?
+			WaitTimeSeconds:     aws.Int64(20), // Probably no need to make it configurable ?
 		})
 		if err != nil {
 			logger.Errorf("Error while polling SQS: %s", err)
@@ -353,7 +353,7 @@ func (s *S3Source) sqsPoll() error {
 			bucket, key, err := s.extractBucketAndPrefix(message.Body)
 			if err != nil {
 				logger.Errorf("Error while parsing SQS message: %s", err)
-				//Always delete the message to avoid infinite loop
+				// Always delete the message to avoid infinite loop
 				_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
 					QueueUrl:      aws.String(s.Config.SQSName),
 					ReceiptHandle: message.ReceiptHandle,
@@ -379,7 +379,7 @@ func (s *S3Source) sqsPoll() error {
 }
 
 func (s *S3Source) readFile(bucket string, key string) error {
-	//TODO: Handle SSE-C
+	// TODO: Handle SSE-C
 	var scanner *bufio.Scanner
 
 	logger := s.logger.WithFields(log.Fields{
@@ -392,14 +392,13 @@ func (s *S3Source) readFile(bucket string, key string) error {
 		Bucket: aws.String(bucket),
 		Key:    aws.String(key),
 	})
-
 	if err != nil {
 		return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err)
 	}
 	defer output.Body.Close()
 
 	if strings.HasSuffix(key, ".gz") {
-		//This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs)
+		// This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs)
 		header := make([]byte, 2)
 		_, err := output.Body.Read(header)
 		if err != nil {
@@ -613,7 +612,7 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *
 	pathParts := strings.Split(args[0], "/")
 	s.logger.Debugf("pathParts: %v", pathParts)
 
-	//FIXME: handle s3://bucket/
+	// FIXME: handle s3://bucket/
 	if len(pathParts) == 1 {
 		s.Config.BucketName = pathParts[0]
 		s.Config.Prefix = ""
@@ -656,7 +655,7 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error
 			return err
 		}
 	} else {
-		//No key, get everything in the bucket based on the prefix
+		// No key, get everything in the bucket based on the prefix
 		objects, err := s.getBucketContent()
 		if err != nil {
 			return err
@@ -672,11 +671,11 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error
 	return nil
 }
 
-func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
+func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	s.t = t
 	s.out = out
-	s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered?
-	s.ctx, s.cancel = context.WithCancel(context.Background())
+	s.readerChan = make(chan S3Object, 100) // FIXME: does this needs to be buffered?
+	s.ctx, s.cancel = context.WithCancel(ctx)
 	s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix)
 	t.Go(func() error {
 		s.readManager()
diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go
index 93e166dfec5..05a974517a0 100644
--- a/pkg/acquisition/modules/s3/s3_test.go
+++ b/pkg/acquisition/modules/s3/s3_test.go
@@ -272,6 +272,7 @@ func TestDSNAcquis(t *testing.T) {
 }
 
 func TestListPolling(t *testing.T) {
+	ctx := context.Background()
 	tests := []struct {
 		name   string
 		config string
@@ -331,7 +332,7 @@ prefix: foo/
 			}
 		}()
 
-		err = f.StreamingAcquisition(out, &tb)
+		err = f.StreamingAcquisition(ctx, out, &tb)
 		if err != nil {
 			t.Fatalf("unexpected error: %s", err.Error())
 		}
@@ -348,6 +349,7 @@ prefix: foo/
 }
 
 func TestSQSPoll(t *testing.T) {
+	ctx := context.Background()
 	tests := []struct {
 		name   string
 		config string
@@ -411,7 +413,7 @@ sqs_name: test
 			}
 		}()
 
-		err = f.StreamingAcquisition(out, &tb)
+		err = f.StreamingAcquisition(ctx, out, &tb)
 		if err != nil {
 			t.Fatalf("unexpected error: %s", err.Error())
 		}
diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go
index 06c32e62f77..5315096fb9b 100644
--- a/pkg/acquisition/modules/syslog/syslog.go
+++ b/pkg/acquisition/modules/syslog/syslog.go
@@ -1,6 +1,7 @@
 package syslogacquisition
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -105,7 +106,7 @@ func (s *SyslogSource) UnmarshalConfig(yamlConfig []byte) error {
 	}
 
 	if s.config.Addr == "" {
-		s.config.Addr = "127.0.0.1" //do we want a usable or secure default ?
+		s.config.Addr = "127.0.0.1" // do we want a usable or secure default ?
 	}
 	if s.config.Port == 0 {
 		s.config.Port = 514
@@ -135,7 +136,7 @@ func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe
 	return nil
 }
 
-func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
+func (s *SyslogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	c := make(chan syslogserver.SyslogMessage)
 	s.server = &syslogserver.SyslogServer{Logger: s.logger.WithField("syslog", "internal"), MaxMessageLen: s.config.MaxMessageLen}
 	s.server.SetChannel(c)
@@ -152,7 +153,8 @@ func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb)
 }
 
 func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string,
-	appname string, pid string, msg string) string {
+	appname string, pid string, msg string,
+) string {
 	ret := ""
 	if !ts.IsZero() {
 		ret += ts.Format("Jan 2 15:04:05")
@@ -178,7 +180,6 @@ func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string,
 		ret += msg
 	}
 	return ret
-
 }
 
 func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c chan syslogserver.SyslogMessage) error {
diff --git a/pkg/acquisition/modules/syslog/syslog_test.go b/pkg/acquisition/modules/syslog/syslog_test.go
index 1750f375138..57fa3e8747b 100644
--- a/pkg/acquisition/modules/syslog/syslog_test.go
+++ b/pkg/acquisition/modules/syslog/syslog_test.go
@@ -1,6 +1,7 @@
 package syslogacquisition
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"runtime"
@@ -80,6 +81,7 @@ func writeToSyslog(logs []string) {
 }
 
 func TestStreamingAcquisition(t *testing.T) {
+	ctx := context.Background()
 	tests := []struct {
 		name          string
 		config        string
@@ -100,8 +102,10 @@ listen_addr: 127.0.0.1`,
 listen_port: 4242
 listen_addr: 127.0.0.1`,
 			expectedLines: 2,
-			logs: []string{`<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`,
-				`<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`},
+			logs: []string{
+				`<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`,
+				`<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`,
+			},
 		},
 		{
 			name: "RFC3164",
@@ -109,10 +113,12 @@ listen_addr: 127.0.0.1`,
 listen_port: 4242
 listen_addr: 127.0.0.1`,
 			expectedLines: 3,
-			logs: []string{`<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`,
+			logs: []string{
+				`<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`,
 				`<13>May 18 12:37:56 mantis sshd[49340]: blabla2`,
 				`<13>May 18 12:37:56 mantis sshd: blabla2`,
-				`<13>May 18 12:37:56 mantis sshd`},
+				`<13>May 18 12:37:56 mantis sshd`,
+			},
 		},
 	}
 	if runtime.GOOS != "windows" {
@@ -139,7 +145,7 @@ listen_addr: 127.0.0.1`,
 			}
 			tomb := tomb.Tomb{}
 			out := make(chan types.Event)
-			err = s.StreamingAcquisition(out, &tomb)
+			err = s.StreamingAcquisition(ctx, out, &tomb)
 			cstest.AssertErrorContains(t, err, ts.expectedErr)
 			if ts.expectedErr != "" {
 				return
diff --git a/pkg/acquisition/modules/wineventlog/wineventlog.go b/pkg/acquisition/modules/wineventlog/wineventlog.go
index 44035d0a708..6d522d8d8cb 100644
--- a/pkg/acquisition/modules/wineventlog/wineventlog.go
+++ b/pkg/acquisition/modules/wineventlog/wineventlog.go
@@ -3,6 +3,7 @@
 package wineventlogacquisition
 
 import (
+	"context"
 	"errors"
 
 	"github.com/prometheus/client_golang/prometheus"
@@ -59,7 +60,7 @@ func (w *WinEventLogSource) CanRun() error {
 	return errors.New("windows event log acquisition is only supported on Windows")
 }
 
-func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
+func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	return nil
 }
 
diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_test.go
index 2ea0e365be5..ae6cb776909 100644
--- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go
+++ b/pkg/acquisition/modules/wineventlog/wineventlog_test.go
@@ -3,6 +3,7 @@
 package wineventlogacquisition
 
 import (
+	"context"
 	"runtime"
 	"testing"
 	"time"
@@ -129,6 +130,7 @@ event_level: bla`,
 }
 
 func TestLiveAcquisition(t *testing.T) {
+	ctx := context.Background()
 	if runtime.GOOS != "windows" {
 		t.Skip("Skipping test on non-windows OS")
 	}
@@ -190,7 +192,7 @@ event_ids:
 			c := make(chan types.Event)
 			f := WinEventLogSource{}
 			f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE)
-			f.StreamingAcquisition(c, to)
+			f.StreamingAcquisition(ctx, c, to)
 			time.Sleep(time.Second)
 			lines := test.expectedLines
 			go func() {
diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go
index 4f2384d71db..087c20eb70e 100644
--- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go
+++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go
@@ -1,6 +1,7 @@
 package wineventlogacquisition
 
 import (
+	"context"
 	"encoding/xml"
 	"errors"
 	"fmt"
@@ -325,7 +326,7 @@ func (w *WinEventLogSource) CanRun() error {
 	return nil
 }
 
-func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
+func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
 	t.Go(func() error {
 		defer trace.CatchPanic("crowdsec/acquis/wineventlog/streaming")
 		return w.getEvents(out, t)
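The change repeated across these modules is mechanical: StreamingAcquisition gains a context.Context as its first argument, and each datasource derives its internal cancellable context from that argument rather than from context.Background(), so cancellation can reach the reader goroutines from the caller. As a minimal sketch of what this means for a datasource implementation (the ExampleSource type, its fields and its placeholder read loop are hypothetical and not part of this patch; only the method signature and the context.WithCancel(ctx) pattern come from the diff above):

package exampleacquisition

import (
	"context"

	tomb "gopkg.in/tomb.v2"

	"github.com/crowdsecurity/crowdsec/pkg/types"
)

// ExampleSource is a hypothetical datasource used only to illustrate the
// new StreamingAcquisition signature.
type ExampleSource struct {
	out    chan types.Event
	ctx    context.Context
	cancel context.CancelFunc
}

func (e *ExampleSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error {
	e.out = out
	// Derive the module's cancellable context from the caller's ctx instead of
	// context.Background(), so cancelling the parent context also stops the readers.
	e.ctx, e.cancel = context.WithCancel(ctx)

	t.Go(func() error {
		defer e.cancel()
		// Placeholder for the module's real read loop: a real implementation
		// would select on e.ctx.Done(), t.Dying() and its input, pushing
		// types.Event values to e.out.
		<-t.Dying()
		return nil
	})

	return nil
}

The caller owns the context, as the tests above do with context.Background(); cancelling it stops the acquisition alongside the tomb.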