Skip to content

Commit

Permalink
Clean up asset fetching logic and add basic tests
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Kruse <[email protected]>
  • Loading branch information
c-kruse committed Oct 6, 2023
1 parent 1473df4 commit 82dc615
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 70 deletions.
14 changes: 8 additions & 6 deletions pkg/receiver/jobreceiver/asset/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@ type Fetcher interface {

// An HTTPFetcher fetches the contents of files at a given URL.
type httpFetcher struct {
client *http.Client
logger *zap.SugaredLogger
client *http.Client
logger *zap.SugaredLogger
makeBackoff func() backoff.BackOff
}

// NewFetcher creates a new HTTP based Fetcher.
//
// Uses an exponential backoff to retry failed requests.
func NewFetcher(log *zap.SugaredLogger, client *http.Client) Fetcher {
return &httpFetcher{
client: client,
logger: log,
client: client,
logger: log,
makeBackoff: func() backoff.BackOff { return backoff.NewExponentialBackOff() },
}
}

Expand All @@ -40,7 +42,7 @@ func NewFetcher(log *zap.SugaredLogger, client *http.Client) Fetcher {
func (h *httpFetcher) Fetch(ctx context.Context, url string) (*os.File, error) {
var fetchErr error
var attempts int
b := backoff.NewExponentialBackOff()
b := h.makeBackoff()
for {
duration := b.NextBackOff()
if duration == backoff.Stop {
Expand All @@ -56,7 +58,7 @@ func (h *httpFetcher) Fetch(ctx context.Context, url string) (*os.File, error) {
h.logger.Errorf("retrying failed asset fetch for %s: %s", url, fetchErr)
}

attempts++
attempts = attempts + 1
out, err := h.tryFetch(ctx, url)
if err != nil {
fetchErr = err
Expand Down
194 changes: 176 additions & 18 deletions pkg/receiver/jobreceiver/asset/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@ package asset

import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

func TestFetcher(t *testing.T) {
t.Parallel()

srv := httptest.NewTLSServer(http.FileServer(http.Dir(fixturePath(""))))
defer srv.Close()
client := srv.Client()
Expand All @@ -25,25 +28,180 @@ func TestFetcher(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

helloTGZURL, err := url.JoinPath(srv.URL, "helloworld.tar.gz")
require.NoError(t, err)
f, err := fetcher.Fetch(ctx, helloTGZURL)
require.NoError(t, err)
t.Run("can fetch tar gz file", func(t *testing.T) {
helloTGZURL, err := url.JoinPath(srv.URL, "helloworld.tar.gz")
require.NoError(t, err)
f, err := fetcher.Fetch(ctx, helloTGZURL)
require.NoError(t, err)

tarGZShaBytes, err := os.ReadFile(fixturePath("helloworld.tar.gz.sha512"))
require.NoError(t, err)
tarGZSHA := strings.TrimSpace(string(tarGZShaBytes))
err = new(sha512Verifier).Verify(f, tarGZSHA)
assert.NoError(t, err)
tarGZSHA := ExtractSHA(t, "helloworld.tar.gz.sha512")
err = new(sha512Verifier).Verify(f, tarGZSHA)
assert.NoError(t, err)
})

helloTURL, err := url.JoinPath(srv.URL, "helloworld.tar")
require.NoError(t, err)
f, err = fetcher.Fetch(ctx, helloTURL)
require.NoError(t, err)
t.Run("can fetch tar file", func(t *testing.T) {
helloTURL, err := url.JoinPath(srv.URL, "helloworld.tar")
require.NoError(t, err)
f, err := fetcher.Fetch(ctx, helloTURL)
require.NoError(t, err)

tarSHA := ExtractSHA(t, "helloworld.tar.sha512")
err = new(sha512Verifier).Verify(f, tarSHA)
assert.NoError(t, err)
})
}

func TestFetcherBackOff(t *testing.T) {
t.Parallel()

observedIntervals := make(chan time.Duration, 16)
respondOK := make(chan []byte, 1)
srv := httptest.NewTLSServer(&intervalObserver{
Response: respondOK,
Out: observedIntervals,
})
defer srv.Close()
fetcher := NewFetcher(zap.NewNop().Sugar(), srv.Client())
fetcher.(*httpFetcher).makeBackoff = func() backoff.BackOff {
return backoff.NewConstantBackOff(time.Millisecond * 150)
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

fetchDone := make(chan struct{})
var fetchResponse *os.File
var fetchErr error
go func() {
fetchResponse, fetchErr = fetcher.Fetch(ctx, srv.URL)
close(fetchDone)
}()

intervals := make([]time.Duration, 0, 3)
assert.Eventually(t, func() bool {
select {
case interval := <-observedIntervals:
intervals = append(intervals, interval)
default:
}
return len(intervals) >= 3
}, time.Second*5, time.Millisecond*150, "expected fetcher to retry repeatedly")

// retried according to backoff schedule with .3 to 10x tolerance
for _, interval := range intervals {
assert.Less(t, time.Millisecond*50, interval)
assert.Greater(t, time.Millisecond*1500, interval)
}

// respond to fetcher sucessfully with payload
respondOK <- []byte("sensu")

assert.Eventually(t, func() bool {
select {
case <-fetchDone:
return true
default:
return false
}
}, time.Second*5, time.Millisecond*150, "expected fetcher to return")

tarShaBytes, err := os.ReadFile(fixturePath("helloworld.tar.sha512"))
require.NoError(t, fetchErr)
defer func() { assert.NoError(t, fetchResponse.Close()) }()
actualText, err := io.ReadAll(fetchResponse)
require.NoError(t, err)
tarSHA := strings.TrimSpace(string(tarShaBytes))
err = new(sha512Verifier).Verify(f, tarSHA)
assert.NoError(t, err)
assert.Equal(t, "sensu", string(actualText))
}

func TestFetcherBackOffExpires(t *testing.T) {
t.Parallel()

observedIntervals := make(chan time.Duration, 16)
srv := httptest.NewTLSServer(&intervalObserver{
Response: make(chan []byte, 1),
Out: observedIntervals,
})
defer srv.Close()
fetcher := NewFetcher(zap.NewNop().Sugar(), srv.Client())
// set backoff max elapsed time to .5s
fetcher.(*httpFetcher).makeBackoff = func() backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.InitialInterval = 50 * time.Millisecond
b.Multiplier = 1
b.RandomizationFactor = 0.1
b.MaxElapsedTime = 2 * time.Second
return b
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

fetchDone := make(chan struct{})
var fetchResponse *os.File
var fetchErr error
go func() {
fetchResponse, fetchErr = fetcher.Fetch(ctx, srv.URL)
close(fetchDone)
}()

atLeastOneRetry := make(chan struct{})
go func() {
<-observedIntervals
close(atLeastOneRetry)
for {
select {
case <-observedIntervals:
case <-ctx.Done():
return
}
}
}()
assert.Eventually(t, func() bool {
select {
case <-atLeastOneRetry:
return true
default:
return false
}
}, time.Second*5, time.Millisecond*150, "expected at least one retry")

assert.Eventually(t, func() bool {
select {
case <-fetchDone:
return true
default:
return false
}
}, time.Second*2, time.Millisecond*50, "expected fetcher to return")

assert.Error(t, fetchErr)
if fetchResponse != nil {
assert.NoError(t, fetchResponse.Close())
}
}

type intervalObserver struct {
Last time.Time

Response <-chan []byte
Out chan<- time.Duration
}

var _ http.Handler = (*intervalObserver)(nil)

func (i *intervalObserver) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
prev := i.Last
i.Last = time.Now()

select {
case resp := <-i.Response:
_, err := rw.Write(resp)
if err != nil {
panic(err)
}
default:
if !prev.IsZero() {
i.Out <- i.Last.Sub(prev)
}
rw.WriteHeader(http.StatusServiceUnavailable)
}
}
63 changes: 63 additions & 0 deletions pkg/receiver/jobreceiver/asset/integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package asset

import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

func TestManagerIntegration(t *testing.T) {
t.Parallel()

srv := httptest.NewTLSServer(http.FileServer(http.Dir(fixturePath(""))))
defer srv.Close()
client := srv.Client()

testStore, remove := tempDir(t)
defer remove()
nopLog := zap.NewNop().Sugar()
manager := Manager{
Fetcher: NewFetcher(nopLog, client),
Logger: nopLog,
StoragePath: testStore,
}
specs := []Spec{
{
Name: "helloworld.tar.gz",
SHA512: ExtractSHA(t, "helloworld.tar.gz.sha512"),
URL: srv.URL + "/helloworld.tar.gz",
}, {
Name: "helloworld.tar",
SHA512: ExtractSHA(t, "helloworld.tar.sha512"),
URL: srv.URL + "/helloworld.tar",
},
}
err := manager.Validate(specs)
assert.NoError(t, err)

env := []string{}
references, err := manager.InstallAll(context.Background(), specs)
require.NoError(t, err)
assert.NotEmpty(t, references)
for _, ref := range references {
info, err := os.Stat(ref.Path)
assert.NoError(t, err)
assert.True(t, info.IsDir())
env = ref.MergeEnvironment(env)
}
var path string
for _, v := range env {
if strings.HasPrefix(v, "PATH") {
path = v
break
}
}
assert.NotEmpty(t, path, env)
}
Loading

0 comments on commit 82dc615

Please sign in to comment.