From 82dc6155c082b6e711834effda011955eed02001 Mon Sep 17 00:00:00 2001 From: Christian Kruse Date: Fri, 6 Oct 2023 11:29:00 -0700 Subject: [PATCH] Clean up asset fetching logic and add basic tests Signed-off-by: Christian Kruse --- pkg/receiver/jobreceiver/asset/fetcher.go | 14 +- .../jobreceiver/asset/fetcher_test.go | 194 ++++++++++++++++-- .../jobreceiver/asset/integration_test.go | 63 ++++++ pkg/receiver/jobreceiver/asset/manager.go | 65 ++---- .../jobreceiver/asset/verifier_test.go | 11 +- 5 files changed, 277 insertions(+), 70 deletions(-) create mode 100644 pkg/receiver/jobreceiver/asset/integration_test.go diff --git a/pkg/receiver/jobreceiver/asset/fetcher.go b/pkg/receiver/jobreceiver/asset/fetcher.go index 88e85b5ea4..51f2c521a4 100644 --- a/pkg/receiver/jobreceiver/asset/fetcher.go +++ b/pkg/receiver/jobreceiver/asset/fetcher.go @@ -21,8 +21,9 @@ type Fetcher interface { // An HTTPFetcher fetches the contents of files at a given URL. type httpFetcher struct { - client *http.Client - logger *zap.SugaredLogger + client *http.Client + logger *zap.SugaredLogger + makeBackoff func() backoff.BackOff } // NewFetcher creates a new HTTP based Fetcher. @@ -30,8 +31,9 @@ type httpFetcher struct { // Uses an exponential backoff to retry failed requests. func NewFetcher(log *zap.SugaredLogger, client *http.Client) Fetcher { return &httpFetcher{ - client: client, - logger: log, + client: client, + logger: log, + makeBackoff: func() backoff.BackOff { return backoff.NewExponentialBackOff() }, } } @@ -40,7 +42,7 @@ func NewFetcher(log *zap.SugaredLogger, client *http.Client) Fetcher { func (h *httpFetcher) Fetch(ctx context.Context, url string) (*os.File, error) { var fetchErr error var attempts int - b := backoff.NewExponentialBackOff() + b := h.makeBackoff() for { duration := b.NextBackOff() if duration == backoff.Stop { @@ -56,7 +58,7 @@ func (h *httpFetcher) Fetch(ctx context.Context, url string) (*os.File, error) { h.logger.Errorf("retrying failed asset fetch for %s: %s", url, fetchErr) } - attempts++ + attempts = attempts + 1 out, err := h.tryFetch(ctx, url) if err != nil { fetchErr = err diff --git a/pkg/receiver/jobreceiver/asset/fetcher_test.go b/pkg/receiver/jobreceiver/asset/fetcher_test.go index c863dea709..6ab5583146 100644 --- a/pkg/receiver/jobreceiver/asset/fetcher_test.go +++ b/pkg/receiver/jobreceiver/asset/fetcher_test.go @@ -2,13 +2,15 @@ package asset import ( "context" + "io" "net/http" "net/http/httptest" "net/url" "os" - "strings" "testing" + "time" + "github.com/cenkalti/backoff/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -16,6 +18,7 @@ import ( func TestFetcher(t *testing.T) { t.Parallel() + srv := httptest.NewTLSServer(http.FileServer(http.Dir(fixturePath("")))) defer srv.Close() client := srv.Client() @@ -25,25 +28,180 @@ func TestFetcher(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - helloTGZURL, err := url.JoinPath(srv.URL, "helloworld.tar.gz") - require.NoError(t, err) - f, err := fetcher.Fetch(ctx, helloTGZURL) - require.NoError(t, err) + t.Run("can fetch tar gz file", func(t *testing.T) { + helloTGZURL, err := url.JoinPath(srv.URL, "helloworld.tar.gz") + require.NoError(t, err) + f, err := fetcher.Fetch(ctx, helloTGZURL) + require.NoError(t, err) - tarGZShaBytes, err := os.ReadFile(fixturePath("helloworld.tar.gz.sha512")) - require.NoError(t, err) - tarGZSHA := strings.TrimSpace(string(tarGZShaBytes)) - err = new(sha512Verifier).Verify(f, tarGZSHA) - assert.NoError(t, err) + tarGZSHA := ExtractSHA(t, "helloworld.tar.gz.sha512") + err = new(sha512Verifier).Verify(f, tarGZSHA) + assert.NoError(t, err) + }) - helloTURL, err := url.JoinPath(srv.URL, "helloworld.tar") - require.NoError(t, err) - f, err = fetcher.Fetch(ctx, helloTURL) - require.NoError(t, err) + t.Run("can fetch tar file", func(t *testing.T) { + helloTURL, err := url.JoinPath(srv.URL, "helloworld.tar") + require.NoError(t, err) + f, err := fetcher.Fetch(ctx, helloTURL) + require.NoError(t, err) + + tarSHA := ExtractSHA(t, "helloworld.tar.sha512") + err = new(sha512Verifier).Verify(f, tarSHA) + assert.NoError(t, err) + }) +} + +func TestFetcherBackOff(t *testing.T) { + t.Parallel() + + observedIntervals := make(chan time.Duration, 16) + respondOK := make(chan []byte, 1) + srv := httptest.NewTLSServer(&intervalObserver{ + Response: respondOK, + Out: observedIntervals, + }) + defer srv.Close() + fetcher := NewFetcher(zap.NewNop().Sugar(), srv.Client()) + fetcher.(*httpFetcher).makeBackoff = func() backoff.BackOff { + return backoff.NewConstantBackOff(time.Millisecond * 150) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fetchDone := make(chan struct{}) + var fetchResponse *os.File + var fetchErr error + go func() { + fetchResponse, fetchErr = fetcher.Fetch(ctx, srv.URL) + close(fetchDone) + }() + + intervals := make([]time.Duration, 0, 3) + assert.Eventually(t, func() bool { + select { + case interval := <-observedIntervals: + intervals = append(intervals, interval) + default: + } + return len(intervals) >= 3 + }, time.Second*5, time.Millisecond*150, "expected fetcher to retry repeatedly") + + // retried according to backoff schedule with .3 to 10x tolerance + for _, interval := range intervals { + assert.Less(t, time.Millisecond*50, interval) + assert.Greater(t, time.Millisecond*1500, interval) + } + + // respond to fetcher sucessfully with payload + respondOK <- []byte("sensu") + + assert.Eventually(t, func() bool { + select { + case <-fetchDone: + return true + default: + return false + } + }, time.Second*5, time.Millisecond*150, "expected fetcher to return") - tarShaBytes, err := os.ReadFile(fixturePath("helloworld.tar.sha512")) + require.NoError(t, fetchErr) + defer func() { assert.NoError(t, fetchResponse.Close()) }() + actualText, err := io.ReadAll(fetchResponse) require.NoError(t, err) - tarSHA := strings.TrimSpace(string(tarShaBytes)) - err = new(sha512Verifier).Verify(f, tarSHA) - assert.NoError(t, err) + assert.Equal(t, "sensu", string(actualText)) +} + +func TestFetcherBackOffExpires(t *testing.T) { + t.Parallel() + + observedIntervals := make(chan time.Duration, 16) + srv := httptest.NewTLSServer(&intervalObserver{ + Response: make(chan []byte, 1), + Out: observedIntervals, + }) + defer srv.Close() + fetcher := NewFetcher(zap.NewNop().Sugar(), srv.Client()) + // set backoff max elapsed time to .5s + fetcher.(*httpFetcher).makeBackoff = func() backoff.BackOff { + b := backoff.NewExponentialBackOff() + b.InitialInterval = 50 * time.Millisecond + b.Multiplier = 1 + b.RandomizationFactor = 0.1 + b.MaxElapsedTime = 2 * time.Second + return b + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fetchDone := make(chan struct{}) + var fetchResponse *os.File + var fetchErr error + go func() { + fetchResponse, fetchErr = fetcher.Fetch(ctx, srv.URL) + close(fetchDone) + }() + + atLeastOneRetry := make(chan struct{}) + go func() { + <-observedIntervals + close(atLeastOneRetry) + for { + select { + case <-observedIntervals: + case <-ctx.Done(): + return + } + } + }() + assert.Eventually(t, func() bool { + select { + case <-atLeastOneRetry: + return true + default: + return false + } + }, time.Second*5, time.Millisecond*150, "expected at least one retry") + + assert.Eventually(t, func() bool { + select { + case <-fetchDone: + return true + default: + return false + } + }, time.Second*2, time.Millisecond*50, "expected fetcher to return") + + assert.Error(t, fetchErr) + if fetchResponse != nil { + assert.NoError(t, fetchResponse.Close()) + } +} + +type intervalObserver struct { + Last time.Time + + Response <-chan []byte + Out chan<- time.Duration +} + +var _ http.Handler = (*intervalObserver)(nil) + +func (i *intervalObserver) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + prev := i.Last + i.Last = time.Now() + + select { + case resp := <-i.Response: + _, err := rw.Write(resp) + if err != nil { + panic(err) + } + default: + if !prev.IsZero() { + i.Out <- i.Last.Sub(prev) + } + rw.WriteHeader(http.StatusServiceUnavailable) + } } diff --git a/pkg/receiver/jobreceiver/asset/integration_test.go b/pkg/receiver/jobreceiver/asset/integration_test.go new file mode 100644 index 0000000000..e54d25121f --- /dev/null +++ b/pkg/receiver/jobreceiver/asset/integration_test.go @@ -0,0 +1,63 @@ +package asset + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestManagerIntegration(t *testing.T) { + t.Parallel() + + srv := httptest.NewTLSServer(http.FileServer(http.Dir(fixturePath("")))) + defer srv.Close() + client := srv.Client() + + testStore, remove := tempDir(t) + defer remove() + nopLog := zap.NewNop().Sugar() + manager := Manager{ + Fetcher: NewFetcher(nopLog, client), + Logger: nopLog, + StoragePath: testStore, + } + specs := []Spec{ + { + Name: "helloworld.tar.gz", + SHA512: ExtractSHA(t, "helloworld.tar.gz.sha512"), + URL: srv.URL + "/helloworld.tar.gz", + }, { + Name: "helloworld.tar", + SHA512: ExtractSHA(t, "helloworld.tar.sha512"), + URL: srv.URL + "/helloworld.tar", + }, + } + err := manager.Validate(specs) + assert.NoError(t, err) + + env := []string{} + references, err := manager.InstallAll(context.Background(), specs) + require.NoError(t, err) + assert.NotEmpty(t, references) + for _, ref := range references { + info, err := os.Stat(ref.Path) + assert.NoError(t, err) + assert.True(t, info.IsDir()) + env = ref.MergeEnvironment(env) + } + var path string + for _, v := range env { + if strings.HasPrefix(v, "PATH") { + path = v + break + } + } + assert.NotEmpty(t, path, env) +} diff --git a/pkg/receiver/jobreceiver/asset/manager.go b/pkg/receiver/jobreceiver/asset/manager.go index 2885ae1ad3..ce8c6399ad 100644 --- a/pkg/receiver/jobreceiver/asset/manager.go +++ b/pkg/receiver/jobreceiver/asset/manager.go @@ -41,54 +41,33 @@ func (m *Manager) Validate(assets []Spec) error { // InstallAll runtime assets on the host file system under the StoragePath // directory. +// +// Loops through the provided asset Specs and ensures they are installed at the +// Manager's StoragePath. Returns the first error encountered in the install. func (m *Manager) InstallAll(ctx context.Context, all []Spec) ([]Reference, error) { - - type installTuple struct { - Err error - Ref Reference - } - - results := make(chan installTuple) - ictx, cancel := context.WithCancel(ctx) - defer cancel() + references := make([]Reference, 0, len(all)) for _, asset := range all { - go func(a Spec) { - var result installTuple - if ref, exists, err := m.get(a); exists { - m.Logger.With( - "asset", result.Ref.Name, - "path", result.Ref.Path, - ).Info("reusing previously installed runtime asset") - result.Ref = ref - } else if err != nil { - result.Err = fmt.Errorf("failed to access asset storage: %s", err) - } else { - result.Ref, result.Err = m.install(ictx, a) - if result.Err == nil { - m.Logger.With( - "asset", result.Ref.Name, - "path", result.Ref.Path, - ).Info("successfully installed asset") - } - } - results <- result - }(asset) - } - - var references []Reference - for i := 0; i < len(all); i++ { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case r := <-results: - if r.Err != nil { - return references, r.Err - } - references = append(references, r.Ref) + if ref, exists, err := m.get(asset); exists { + m.Logger.With( + "asset", ref.Name, + "path", ref.Path, + ).Info("reusing previously installed runtime asset") + references = append(references, ref) + continue + } else if err != nil { + return references, fmt.Errorf("failed to access asset storage: %s", err) + } + ref, err := m.install(ctx, asset) + if err != nil { + return references, fmt.Errorf("failed to retrieve runtime asset %s: %s", asset.Name, err) } + references = append(references, ref) + m.Logger.With( + "asset", asset.Name, + "path", ref.Path, + ).Info("successfully installed asset") } - return references, nil } diff --git a/pkg/receiver/jobreceiver/asset/verifier_test.go b/pkg/receiver/jobreceiver/asset/verifier_test.go index b6bdadb1df..d34bc21bf9 100644 --- a/pkg/receiver/jobreceiver/asset/verifier_test.go +++ b/pkg/receiver/jobreceiver/asset/verifier_test.go @@ -16,9 +16,7 @@ func TestVerify(t *testing.T) { f, err := os.Open(assetPath) require.NoError(t, err) - shaFileContent, err := os.ReadFile(fixturePath("helloworld.tar.gz.sha512")) - require.NoError(t, err) - expectedSha := strings.TrimSpace(string(shaFileContent)) + expectedSha := ExtractSHA(t, "helloworld.tar.gz.sha512") err = new(sha512Verifier).Verify(f, expectedSha) assert.NoError(t, err) @@ -27,3 +25,10 @@ func TestVerify(t *testing.T) { err = new(sha512Verifier).Verify(f, badSha) assert.Error(t, err) } + +func ExtractSHA(t *testing.T, fileName string) string { + t.Helper() + shaFileContent, err := os.ReadFile(fixturePath(fileName)) + require.NoError(t, err) + return strings.TrimSpace(string(shaFileContent)) +}