diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 460e8ab4328..c44d71d2093 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -169,7 +169,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H log.Info("Starting processing data") - if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil { + if err := acquisition.StartAcquisition(context.TODO(), dataSources, inputLineChan, &acquisTomb); err != nil { return fmt.Errorf("starting acquisition error: %w", err) } diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 4a5226a2981..003a8cfd085 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -1,6 +1,7 @@ package acquisition import ( + "context" "errors" "fmt" "io" @@ -47,7 +48,7 @@ type DataSource interface { GetMode() string // Get the mode (TAIL, CAT or SERVER) GetName() string // Get the name of the module OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) - StreamingAcquisition(chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) + StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) GetUuid() string // Get the unique identifier of the datasource Dump() interface{} @@ -375,7 +376,7 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo } } -func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +func StartAcquisition(ctx context.Context, sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { // Don't wait if we have no sources, as it will hang forever if len(sources) == 0 { return nil @@ -405,7 +406,7 @@ func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb }) } if subsrc.GetMode() == configuration.TAIL_MODE { - err = subsrc.StreamingAcquisition(outChan, AcquisTomb) + err = subsrc.StreamingAcquisition(ctx, outChan, AcquisTomb) } else { err = subsrc.OneShotAcquisition(outChan, AcquisTomb) } diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index e39199f9cdb..ded784afb46 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -1,6 +1,7 @@ package acquisition import ( + "context" "errors" "fmt" "strings" @@ -58,7 +59,7 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry, metricsLevel int) } func (f *MockSource) GetMode() string { return f.Mode } func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSource) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { return nil } func (f *MockSource) CanRun() error { return nil } func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } @@ -327,7 +328,7 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro return nil } -func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { +func (f *MockCat) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { return errors.New("can't run in tail") } func (f *MockCat) CanRun() error { return nil } @@ -366,7 +367,7 @@ func (f *MockTail) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) err return errors.New("can't run in cat mode") } -func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *MockTail) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" @@ -389,6 +390,7 @@ func (f *MockTail) GetUuid() string { return "" } // func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { func TestStartAcquisitionCat(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockCat{}, } @@ -396,7 +398,7 @@ func TestStartAcquisitionCat(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -416,6 +418,7 @@ READLOOP: } func TestStartAcquisitionTail(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTail{}, } @@ -423,7 +426,7 @@ func TestStartAcquisitionTail(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -450,7 +453,7 @@ type MockTailError struct { MockTail } -func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *MockTailError) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" @@ -463,6 +466,7 @@ func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) } func TestStartAcquisitionTailError(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTailError{}, } @@ -470,7 +474,7 @@ func TestStartAcquisitionTailError(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { t.Errorf("expected error, got '%s'", err) } }() @@ -503,7 +507,7 @@ func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry, metricsLevel } func (f *MockSourceByDSN) GetMode() string { return f.Mode } func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } +func (f *MockSourceByDSN) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { return nil } func (f *MockSourceByDSN) CanRun() error { return nil } func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 8a93326c7e3..d47062a21df 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -241,7 +241,7 @@ func (w *AppsecSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er return errors.New("AppSec datasource does not support command line acquisition") } -func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { w.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/appsec/live") diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go index d6f33b68050..10151fbd8a9 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go @@ -243,7 +243,7 @@ func (cw *CloudwatchSource) newClient() error { return nil } -func (cw *CloudwatchSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (cw *CloudwatchSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { cw.t = t monitChan := make(chan LogStreamTailConfig) t.Go(func() error { diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go index bab7593f26f..1f8055b6b88 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go @@ -1,6 +1,7 @@ package cloudwatchacquisition import ( + "context" "errors" "fmt" "net" @@ -74,6 +75,7 @@ func TestMain(m *testing.M) { } func TestWatchLogGroupForStreams(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -447,7 +449,7 @@ stream_name: test_stream`), dbgLogger.Infof("running StreamingAcquisition") actmb := tomb.Tomb{} actmb.Go(func() error { - err := cw.StreamingAcquisition(out, &actmb) + err := cw.StreamingAcquisition(ctx, out, &actmb) dbgLogger.Infof("acquis done") cstest.RequireErrorContains(t, err, tc.expectedStartErr) return nil @@ -513,6 +515,7 @@ stream_name: test_stream`), } func TestConfiguration(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -571,7 +574,7 @@ stream_name: test_stream`), switch cw.GetMode() { case "tail": - err = cw.StreamingAcquisition(out, &tmb) + err = cw.StreamingAcquisition(ctx, out, &tmb) case "cat": err = cw.OneShotAcquisition(out, &tmb) } diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go index 44fee0a99a2..67504f255d0 100644 --- a/pkg/acquisition/modules/docker/docker.go +++ b/pkg/acquisition/modules/docker/docker.go @@ -518,7 +518,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha } } -func (d *DockerSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (d *DockerSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { d.t = t monitChan := make(chan *ContainerConfig) deleteChan := make(chan *ContainerConfig) diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index e332569fb3a..05ec81d81b9 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -120,6 +120,7 @@ type mockDockerCli struct { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") @@ -185,7 +186,7 @@ container_name_regexp: readerTomb := &tomb.Tomb{} streamTomb := tomb.Tomb{} streamTomb.Go(func() error { - return dockerSource.StreamingAcquisition(out, &dockerTomb) + return dockerSource.StreamingAcquisition(ctx, out, &dockerTomb) }) readerTomb.Go(func() error { time.Sleep(1 * time.Second) diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go index 85b4c1b5b32..2d2df3ff4d4 100644 --- a/pkg/acquisition/modules/file/file.go +++ b/pkg/acquisition/modules/file/file.go @@ -3,6 +3,7 @@ package fileacquisition import ( "bufio" "compress/gzip" + "context" "errors" "fmt" "io" @@ -320,7 +321,7 @@ func (f *FileSource) CanRun() error { return nil } -func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *FileSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("Starting live acquisition") t.Go(func() error { return f.monitorNewFiles(out, t) diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index 5d38552b3c5..3db0042ba2f 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,6 +1,7 @@ package fileacquisition_test import ( + "context" "fmt" "os" "runtime" @@ -243,6 +244,7 @@ filename: test_files/test_delete.log`, } func TestLiveAcquisition(t *testing.T) { + ctx := context.Background() permDeniedFile := "/etc/shadow" permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied" testPattern := "test_files/*.log" @@ -394,7 +396,7 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(out, &tomb) + err = f.StreamingAcquisition(ctx, out, &tomb) cstest.RequireErrorContains(t, err, tc.expectedErr) if tc.expectedLines != 0 { diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go index 1336fac4578..ac9623e30cb 100644 --- a/pkg/acquisition/modules/journalctl/journalctl.go +++ b/pkg/acquisition/modules/journalctl/journalctl.go @@ -269,7 +269,7 @@ func (j *JournalCtlSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb return err } -func (j *JournalCtlSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (j *JournalCtlSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/journalctl/streaming") return j.runJournalCtl(out, t) diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index 53e2d0802ad..c416bb5d23e 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -1,6 +1,7 @@ package journalctlacquisition import ( + "context" "os" "os/exec" "path/filepath" @@ -187,6 +188,7 @@ journalctl_filter: } func TestStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -250,7 +252,7 @@ journalctl_filter: }() } - err = j.StreamingAcquisition(out, &tomb) + err = j.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if err != nil { diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go index a0d7fc39bcc..c11b4b34173 100644 --- a/pkg/acquisition/modules/kafka/kafka.go +++ b/pkg/acquisition/modules/kafka/kafka.go @@ -204,7 +204,7 @@ func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { } } -func (k *KafkaSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KafkaSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { k.logger.Infof("start reader on brokers '%+v' with topic '%s'", k.Config.Brokers, k.Config.Topic) t.Go(func() error { diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 7b467142cc9..d796166a6ca 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -80,9 +80,9 @@ group_id: crowdsec`, } } -func writeToKafka(w *kafka.Writer, logs []string) { +func writeToKafka(ctx context.Context, w *kafka.Writer, logs []string) { for idx, log := range logs { - err := w.WriteMessages(context.Background(), kafka.Message{ + err := w.WriteMessages(ctx, kafka.Message{ Key: []byte(strconv.Itoa(idx)), // create an arbitrary message payload for the value Value: []byte(log), @@ -128,6 +128,7 @@ func createTopic(topic string, broker string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -176,12 +177,12 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w, ts.logs) + go writeToKafka(ctx, w, ts.logs) READLOOP: for { select { @@ -199,6 +200,7 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) } func TestStreamingAcquisitionWithSSL(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -252,12 +254,12 @@ tls: tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w2, ts.logs) + go writeToKafka(ctx, w2, ts.logs) READLOOP: for { select { diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index 0e6c1980fa9..12c8410c2e9 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "errors" "fmt" @@ -520,7 +521,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error } } -func (k *KinesisSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming") if k.Config.UseEnhancedFanOut { diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go index 46e404aa49b..815f4a7c8ce 100644 --- a/pkg/acquisition/modules/kinesis/kinesis_test.go +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "fmt" "net" @@ -149,6 +150,7 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, } func TestReadFromStream(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -176,7 +178,7 @@ stream_name: stream-1-shard`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } @@ -193,6 +195,7 @@ stream_name: stream-1-shard`, } func TestReadFromMultipleShards(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -220,7 +223,7 @@ stream_name: stream-2-shards`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } @@ -239,6 +242,7 @@ stream_name: stream-2-shards`, } func TestFromSubscription(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -267,7 +271,7 @@ from_subscription: true`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go index 8ba5b2d06e0..27cbf1bc1c0 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go @@ -135,7 +135,7 @@ func (ka *KubernetesAuditSource) OneShotAcquisition(out chan types.Event, t *tom return errors.New("k8s-audit datasource does not support one-shot acquisition") } -func (ka *KubernetesAuditSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { ka.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live") diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index 020bd4c91a0..bc48ee6ae70 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,6 +1,7 @@ package kubernetesauditacquisition import ( + "context" "net/http/httptest" "strings" "testing" @@ -52,6 +53,7 @@ listen_addr: 0.0.0.0`, } func TestInvalidConfig(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -83,7 +85,7 @@ webhook_path: /k8s-audit`, err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) @@ -98,6 +100,7 @@ webhook_path: /k8s-audit`, } func TestHandler(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -257,7 +260,7 @@ webhook_path: /k8s-audit`, req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) f.webhookHandler(w, req) diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go index 15c454723ee..f867feeb84b 100644 --- a/pkg/acquisition/modules/loki/loki.go +++ b/pkg/acquisition/modules/loki/loki.go @@ -319,9 +319,9 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri } } -func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (l *LokiSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(context.Background(), l.Config.WaitForReady) + readyCtx, cancel := context.WithTimeout(ctx, l.Config.WaitForReady) defer cancel() err := l.Client.Ready(readyCtx) if err != nil { @@ -329,7 +329,7 @@ func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) er } ll := l.logger.WithField("websocket_url", l.lokiWebsocket) t.Go(func() error { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() respChan := l.Client.QueryRange(ctx, true) if err != nil { diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index 2fd2b61e995..627200217f5 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -439,7 +439,7 @@ query: > t.Fatalf("Unexpected error : %s", err) } - err = lokiSource.StreamingAcquisition(out, &lokiTomb) + err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb) cstest.AssertErrorContains(t, err, ts.streamErr) if ts.streamErr != "" { @@ -449,7 +449,7 @@ query: > time.Sleep(time.Second * 2) // We need to give time to start reading from the WS readTomb := tomb.Tomb{} - readCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) + readCtx, cancel := context.WithTimeout(ctx, time.Second*10) count := 0 readTomb.Go(func() error { @@ -492,6 +492,7 @@ query: > } func TestStopStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -519,15 +520,13 @@ query: > lokiTomb := &tomb.Tomb{} - err = lokiSource.StreamingAcquisition(out, lokiTomb) + err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb) if err != nil { t.Fatalf("Unexpected error : %s", err) } time.Sleep(time.Second * 2) - ctx := context.Background() - err = feedLoki(ctx, subLogger, 1, title) if err != nil { t.Fatalf("Unexpected error : %s", err) diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index a9835ab4974..65bfcdd66ab 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -672,11 +672,11 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return nil } -func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { s.t = t s.out = out s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(context.Background()) + s.ctx, s.cancel = context.WithCancel(ctx) s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) t.Go(func() error { s.readManager() diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go index 93e166dfec5..05a974517a0 100644 --- a/pkg/acquisition/modules/s3/s3_test.go +++ b/pkg/acquisition/modules/s3/s3_test.go @@ -272,6 +272,7 @@ func TestDSNAcquis(t *testing.T) { } func TestListPolling(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -331,7 +332,7 @@ prefix: foo/ } }() - err = f.StreamingAcquisition(out, &tb) + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -348,6 +349,7 @@ prefix: foo/ } func TestSQSPoll(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -411,7 +413,7 @@ sqs_name: test } }() - err = f.StreamingAcquisition(out, &tb) + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go index 06c32e62f77..54147f6d070 100644 --- a/pkg/acquisition/modules/syslog/syslog.go +++ b/pkg/acquisition/modules/syslog/syslog.go @@ -1,6 +1,7 @@ package syslogacquisition import ( + "context" "errors" "fmt" "net" @@ -135,7 +136,7 @@ func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe return nil } -func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *SyslogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { c := make(chan syslogserver.SyslogMessage) s.server = &syslogserver.SyslogServer{Logger: s.logger.WithField("syslog", "internal"), MaxMessageLen: s.config.MaxMessageLen} s.server.SetChannel(c) diff --git a/pkg/acquisition/modules/syslog/syslog_test.go b/pkg/acquisition/modules/syslog/syslog_test.go index 1750f375138..b771b5b1402 100644 --- a/pkg/acquisition/modules/syslog/syslog_test.go +++ b/pkg/acquisition/modules/syslog/syslog_test.go @@ -1,6 +1,7 @@ package syslogacquisition import ( + "context" "fmt" "net" "runtime" @@ -80,6 +81,7 @@ func writeToSyslog(logs []string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -139,7 +141,7 @@ listen_addr: 127.0.0.1`, } tomb := tomb.Tomb{} out := make(chan types.Event) - err = s.StreamingAcquisition(out, &tomb) + err = s.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if ts.expectedErr != "" { return diff --git a/pkg/acquisition/modules/wineventlog/wineventlog.go b/pkg/acquisition/modules/wineventlog/wineventlog.go index 44035d0a708..6d522d8d8cb 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "errors" "github.com/prometheus/client_golang/prometheus" @@ -59,7 +60,7 @@ func (w *WinEventLogSource) CanRun() error { return errors.New("windows event log acquisition is only supported on Windows") } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { return nil } diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_test.go index 2ea0e365be5..ae6cb776909 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_test.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "runtime" "testing" "time" @@ -129,6 +130,7 @@ event_level: bla`, } func TestLiveAcquisition(t *testing.T) { + ctx := context.Background() if runtime.GOOS != "windows" { t.Skip("Skipping test on non-windows OS") } @@ -190,7 +192,7 @@ event_ids: c := make(chan types.Event) f := WinEventLogSource{} f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) - f.StreamingAcquisition(c, to) + f.StreamingAcquisition(ctx, c, to) time.Sleep(time.Second) lines := test.expectedLines go func() { diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go index 4f2384d71db..087c20eb70e 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go @@ -1,6 +1,7 @@ package wineventlogacquisition import ( + "context" "encoding/xml" "errors" "fmt" @@ -325,7 +326,7 @@ func (w *WinEventLogSource) CanRun() error { return nil } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/wineventlog/streaming") return w.getEvents(out, t)