From 4fd2453752027724029894c798ec724a6fd76927 Mon Sep 17 00:00:00 2001
From: Billy Zha <jinzha1@microsoft.com>
Date: Fri, 3 Nov 2023 10:32:56 +0000
Subject: [PATCH] fix: atomic bool for layer skipped during pulling

Signed-off-by: Billy Zha <jinzha1@microsoft.com>
---
 cmd/oras/root/pull.go | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/cmd/oras/root/pull.go b/cmd/oras/root/pull.go
index 844fd988b..e98b03691 100644
--- a/cmd/oras/root/pull.go
+++ b/cmd/oras/root/pull.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"io"
 	"sync"
+	"sync/atomic"
 
 	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
 	"github.com/spf13/cobra"
@@ -130,7 +131,7 @@ func runPull(ctx context.Context, opts pullOptions) error {
 	dst.AllowPathTraversalOnWrite = opts.PathTraversal
 	dst.DisableOverwrite = opts.KeepOldFiles
 
-	desc, skippedLayers, err := doPull(ctx, src, dst, copyOptions, &opts)
+	desc, layerSkipped, err := doPull(ctx, src, dst, copyOptions, &opts)
 	if err != nil {
 		if errors.Is(err, file.ErrPathTraversalDisallowed) {
 			err = fmt.Errorf("%s: %w", "use flag --allow-path-traversal to allow insecurely pulling files outside of working directory", err)
@@ -139,7 +140,7 @@ func runPull(ctx context.Context, opts pullOptions) error {
 	}
 
 	// suggest oras copy for pulling layers without annotation
-	if skippedLayers > 0 {
+	if layerSkipped {
 		fmt.Printf("Skipped pulling layers without file name in %q\n", ocispec.AnnotationTitle)
 		fmt.Printf("Use 'oras copy %s --to-oci-layout <layout-dir>' to pull all layers.\n", opts.RawReference)
 	} else {
@@ -149,13 +150,13 @@ func runPull(ctx context.Context, opts pullOptions) error {
 	return nil
 }
 
-func doPull(ctx context.Context, src oras.ReadOnlyTarget, dst oras.GraphTarget, opts oras.CopyOptions, po *pullOptions) (ocispec.Descriptor, int, error) {
+func doPull(ctx context.Context, src oras.ReadOnlyTarget, dst oras.GraphTarget, opts oras.CopyOptions, po *pullOptions) (ocispec.Descriptor, bool, error) {
 	var configPath, configMediaType string
 	var err error
 	if po.ManifestConfigRef != "" {
 		configPath, configMediaType, err = fileref.Parse(po.ManifestConfigRef, "")
 		if err != nil {
-			return ocispec.Descriptor{}, 0, err
+			return ocispec.Descriptor{}, false, err
 		}
 	}
 
@@ -171,12 +172,12 @@ func doPull(ctx context.Context, src oras.ReadOnlyTarget, dst oras.GraphTarget,
 	var tracked track.GraphTarget
 	dst, tracked, err = getTrackedTarget(dst, po.TTY, "Downloading", "Pulled     ")
 	if err != nil {
-		return ocispec.Descriptor{}, 0, err
+		return ocispec.Descriptor{}, false, err
 	}
 	if tracked != nil {
 		defer tracked.Close()
 	}
-	skippedLayers := 0
+	var layerSkipped atomic.Bool
 	var printed sync.Map
 	var getConfigOnce sync.Once
 	opts.FindSuccessors = func(ctx context.Context, fetcher content.Fetcher, desc ocispec.Descriptor) ([]ocispec.Descriptor, error) {
@@ -236,7 +237,7 @@ func doPull(ctx context.Context, src oras.ReadOnlyTarget, dst oras.GraphTarget,
 				}
 				if s.Annotations[ocispec.AnnotationTitle] == "" {
 					// unnamed layers are skipped
-					skippedLayers++
+					layerSkipped.Store(true)
 				}
 				ss, err := content.Successors(ctx, fetcher, s)
 				if err != nil {
@@ -293,7 +294,7 @@ func doPull(ctx context.Context, src oras.ReadOnlyTarget, dst oras.GraphTarget,
 
 	// Copy
 	desc, err := oras.Copy(ctx, src, po.Reference, dst, po.Reference, opts)
-	return desc, skippedLayers, err
+	return desc, layerSkipped.Load(), err
 }
 
 // generateContentKey generates a unique key for each content descriptor, using