Skip to content

Commit

Permalink
Fix OCI test flakiness (#1447)
Browse files Browse the repository at this point in the history
Fixes #1435 

## Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Other (security config, docs update, etc)

## Checklist before merging

- [x] Test, docs, adr added or updated as needed
- [x] [Contributor Guide
Steps](https://github.com/defenseunicorns/zarf/blob/main/CONTRIBUTING.md#developer-workflow)
followed

---------

Signed-off-by: razzle <[email protected]>
Co-authored-by: Wayne Starr <[email protected]>
  • Loading branch information
Noxsios and Racer159 authored Mar 21, 2023
1 parent 4ce7ef3 commit 4227aca
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 113 deletions.
1 change: 0 additions & 1 deletion src/pkg/packager/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ func (p *Packager) handleOciPackage() error {
doneSaving := make(chan int)
var wg sync.WaitGroup
wg.Add(1)
src.ProgressBar = nil // NOTE: Disable this inbuilt progress bar so we don't double render a spinner
go utils.RenderProgressBarForLocalDirWrite(outDir, estimatedBytes, &wg, doneSaving, "Pulling Zarf package data")

copyOpts := oras.DefaultCopyOptions
Expand Down
12 changes: 6 additions & 6 deletions src/pkg/packager/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (p *Packager) publish(ref registry.Reference, paths []string) error {
root, err = p.publishArtifact(dst, src, descs, copyOpts)
if err != nil {
// reset the progress bar between attempts
dst.ProgressBar.Stop()
dst.Transport.ProgressBar.Stop()

// log the error, the expected error is a 400 manifest invalid
message.Debug("ArtifactManifest push failed with the following error, falling back to an ImageManifest push:", err)
Expand All @@ -192,7 +192,7 @@ func (p *Packager) publish(ref registry.Reference, paths []string) error {
return err
}
}
dst.ProgressBar.Successf("Published %s [%s]", ref, root.MediaType)
dst.Transport.ProgressBar.Successf("Published %s [%s]", ref, root.MediaType)
fmt.Println()
flags := ""
if config.CommonOptions.Insecure {
Expand All @@ -219,8 +219,8 @@ func (p *Packager) publishArtifact(dst *utils.OrasRemote, src *file.Store, descs
}
total += root.Size

dst.ProgressBar = message.NewProgressBar(total, fmt.Sprintf("Publishing %s:%s", dst.Reference.Repository, dst.Reference.Reference))
defer dst.ProgressBar.Stop()
dst.Transport.ProgressBar = message.NewProgressBar(total, fmt.Sprintf("Publishing %s:%s", dst.Reference.Repository, dst.Reference.Reference))
defer dst.Transport.ProgressBar.Stop()

// attempt to push the artifact manifest
_, err = oras.Copy(dst.Context, src, root.Digest.String(), dst, dst.Reference.Reference, copyOpts)
Expand Down Expand Up @@ -257,8 +257,8 @@ func (p *Packager) publishImage(dst *utils.OrasRemote, src *file.Store, descs []
}
total += root.Size + manifestConfigDesc.Size

dst.ProgressBar = message.NewProgressBar(total, fmt.Sprintf("Publishing %s:%s", dst.Reference.Repository, dst.Reference.Reference))
defer dst.ProgressBar.Stop()
dst.Transport.ProgressBar = message.NewProgressBar(total, fmt.Sprintf("Publishing %s:%s", dst.Reference.Repository, dst.Reference.Reference))
defer dst.Transport.ProgressBar.Stop()
// attempt to push the image manifest
_, err = oras.Copy(dst.Context, src, root.Digest.String(), dst, dst.Reference.Reference, copyOpts)
if err != nil {
Expand Down
97 changes: 97 additions & 0 deletions src/pkg/utils/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2021-Present The Zarf Authors

// Package utils provides generic helper functions.
package utils

import (
"io"
"net/http"
"time"

"github.com/defenseunicorns/zarf/src/pkg/message"
"oras.land/oras-go/v2/registry/remote/retry"
)

// Transport is an http.RoundTripper that keeps track of the in-flight
// request and add hooks to report upload progress.
type Transport struct {
Base http.RoundTripper
ProgressBar *message.ProgressBar
}

// NewTransport returns a custom transport that tracks an http.RoundTripper and a message.ProgressBar.
func NewTransport(base http.RoundTripper, bar *message.ProgressBar) *Transport {
return &Transport{
Base: base,
ProgressBar: bar,
}
}

// RoundTrip is mirrored from retry, but instead of calling retry's private t.roundTrip(), this uses
// our own which has interactions w/ message.ProgressBar
//
// https://github.com/oras-project/oras-go/blob/main/registry/remote/retry/client.go
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
policy := retry.DefaultPolicy
attempt := 0
for {
resp, respErr := t.roundTrip(req)
duration, err := policy.Retry(attempt, resp, respErr)
if err != nil {
if respErr == nil {
resp.Body.Close()
}
return nil, err
}
if duration < 0 {
return resp, respErr
}

// rewind the body if possible
if req.Body != nil {
if req.GetBody == nil {
// body can't be rewound, so we can't retry
return resp, respErr
}
body, err := req.GetBody()
if err != nil {
// failed to rewind the body, so we can't retry
return resp, respErr
}
req.Body = body
}

// close the response body if needed
if respErr == nil {
resp.Body.Close()
}

timer := time.NewTimer(duration)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
attempt++
}
}

// roundTrip calls base roundtrip while keeping track of the current request.
// this is currently only used to track the progress of publishes, not pulls.
func (t *Transport) roundTrip(req *http.Request) (resp *http.Response, err error) {
if req.Method != http.MethodHead && req.Body != nil && t.ProgressBar != nil {
req.Body = io.NopCloser(io.TeeReader(req.Body, t.ProgressBar))
}

resp, err = t.Base.RoundTrip(req)

if resp != nil && req.Method == http.MethodHead && err == nil && t.ProgressBar != nil {
if resp.ContentLength > 0 {
t.ProgressBar.Add(int(resp.ContentLength))
}
}
return resp, err
}
45 changes: 4 additions & 41 deletions src/pkg/utils/oras.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"

zarfconfig "github.com/defenseunicorns/zarf/src/config"
Expand All @@ -19,14 +18,13 @@ import (
"oras.land/oras-go/v2/registry"
"oras.land/oras-go/v2/registry/remote"
"oras.land/oras-go/v2/registry/remote/auth"
"oras.land/oras-go/v2/registry/remote/retry"
)

// OrasRemote is a wrapper around the Oras remote repository that includes a progress bar for interactive feedback.
type OrasRemote struct {
*remote.Repository
context.Context
*message.ProgressBar
Transport *Transport
}

// withScopes returns a context with the given scopes.
Expand Down Expand Up @@ -75,55 +73,20 @@ func (o *OrasRemote) withAuthClient(ref registry.Reference) (*auth.Client, error
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig.InsecureSkipVerify = zarfconfig.CommonOptions.Insecure

o.Transport = NewTransport(transport, nil)

client := &auth.Client{
Credential: auth.StaticCredential(ref.Registry, cred),
Cache: auth.NewCache(),
Client: &http.Client{
Transport: retry.NewTransport(transport),
Transport: o.Transport,
},
}
client.SetUserAgent("zarf/" + zarfconfig.CLIVersion)

client.Client.Transport = NewTransport(client.Client.Transport, o)

return client, nil
}

// Transport is an http.RoundTripper that keeps track of the in-flight
// request and add hooks to report HTTP tracing events.
type Transport struct {
http.RoundTripper
orasRemote *OrasRemote
}

// NewTransport returns a custom transport that tracks an http.RoundTripper and an OrasRemote reference.
func NewTransport(base http.RoundTripper, o *OrasRemote) *Transport {
return &Transport{base, o}
}

type readCloser struct {
io.Reader
io.Closer
}

// RoundTrip calls base roundtrip while keeping track of the current request.
// This is currently only used to track the progress of publishes, not pulls.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
if req.Method != http.MethodHead && req.Body != nil && t.orasRemote.ProgressBar != nil {
tee := io.TeeReader(req.Body, t.orasRemote.ProgressBar)
teeCloser := readCloser{tee, req.Body}
req.Body = teeCloser
}

resp, err = t.RoundTripper.RoundTrip(req)

if req.Method == http.MethodHead && err == nil && t.orasRemote.ProgressBar != nil && resp.ContentLength > 0 {
t.orasRemote.ProgressBar.Add(int(resp.ContentLength))
}

return resp, err
}

// NewOrasRemote returns an oras remote repository client and context for the given reference.
func NewOrasRemote(ref registry.Reference) (*OrasRemote, error) {
o := &OrasRemote{}
Expand Down
Loading

0 comments on commit 4227aca

Please sign in to comment.