diff --git a/src/internal/packager/images/pull.go b/src/internal/packager/images/pull.go
index ee058c6e67..ce13f134f0 100644
--- a/src/internal/packager/images/pull.go
+++ b/src/internal/packager/images/pull.go
@@ -8,6 +8,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
 	"strings"
@@ -23,6 +24,10 @@ import (
 	v1 "github.com/google/go-containerregistry/pkg/v1"
 	"github.com/google/go-containerregistry/pkg/v1/cache"
 	"github.com/google/go-containerregistry/pkg/v1/daemon"
+	"github.com/google/go-containerregistry/pkg/v1/empty"
+	"github.com/google/go-containerregistry/pkg/v1/layout"
+	"github.com/google/go-containerregistry/pkg/v1/partial"
+	"github.com/google/go-containerregistry/pkg/v1/stream"
 	"github.com/moby/moby/client"
 	"github.com/pterm/pterm"
 )
@@ -50,27 +55,73 @@ func (i *ImgConfig) PullAll() error {
 	logs.Warn.SetOutput(&message.DebugWriter{})
 	logs.Progress.SetOutput(&message.DebugWriter{})
 
-	for idx, src := range i.ImgList {
-		spinner.Updatef("Fetching image metadata (%d of %d): %s", idx+1, imgCount, src)
+	type srcAndImg struct {
+		src string
+		img v1.Image
+	}
 
-		srcParsed, err := transform.ParseImageRef(src)
-		if err != nil {
-			return fmt.Errorf("failed to parse image ref %s: %w", src, err)
-		}
+	metadataImageConcurrency := utils.NewConcurrencyTools[srcAndImg, error](len(i.ImgList))
+
+	defer metadataImageConcurrency.Cancel()
 
-		actualSrc := src
-		if overrideHost, present := i.RegistryOverrides[srcParsed.Host]; present {
-			actualSrc, err = transform.ImageTransformHostWithoutChecksum(overrideHost, src)
+	spinner.Updatef("Fetching image metadata (0 of %d)", len(i.ImgList))
+
+	// Spawn a goroutine for each image to load its metadata
+	for _, src := range i.ImgList {
+		// Rebind the loop variable so that each goroutine captures its own copy
+		src := src
+		go func() {
+			// Make sure to call Done() on the WaitGroup when the goroutine finishes
+			defer metadataImageConcurrency.WaitGroupDone()
+
+			srcParsed, err := transform.ParseImageRef(src)
 			if err != nil {
-				return fmt.Errorf("failed to swap override host %s for %s: %w", overrideHost, src, err)
+				metadataImageConcurrency.ErrorChan <- fmt.Errorf("failed to parse image ref %s: %w", src, err)
+				return
 			}
-		}
 
-		img, err := i.PullImage(actualSrc, spinner)
-		if err != nil {
-			return fmt.Errorf("failed to pull image %s: %w", actualSrc, err)
-		}
-		imageMap[src] = img
+			if metadataImageConcurrency.IsDone() {
+				return
+			}
+
+			actualSrc := src
+			if overrideHost, present := i.RegistryOverrides[srcParsed.Host]; present {
+				actualSrc, err = transform.ImageTransformHostWithoutChecksum(overrideHost, src)
+				if err != nil {
+					metadataImageConcurrency.ErrorChan <- fmt.Errorf("failed to swap override host %s for %s: %w", overrideHost, src, err)
+					return
+				}
+			}
+
+			if metadataImageConcurrency.IsDone() {
+				return
+			}
+
+			img, err := i.PullImage(actualSrc, spinner)
+			if err != nil {
+				metadataImageConcurrency.ErrorChan <- fmt.Errorf("failed to pull image %s: %w", actualSrc, err)
+				return
+			}
+
+			if metadataImageConcurrency.IsDone() {
+				return
+			}
+
+			metadataImageConcurrency.ProgressChan <- srcAndImg{src: src, img: img}
+		}()
+	}
+
+	onMetadataProgress := func(finishedImage srcAndImg, iteration int) {
+		spinner.Updatef("Fetching image metadata (%d of %d): %s", iteration+1, len(i.ImgList), finishedImage.src)
+		imageMap[finishedImage.src] = finishedImage.img
+	}
+
+	onMetadataError := func(err error) error {
+		return fmt.Errorf("failed to load metadata for all images - this may be due to a network error or an invalid image reference: %w", err)
+	}
+
+	if err := metadataImageConcurrency.WaitWithProgress(onMetadataProgress, onMetadataError); err != nil {
+		return err
 	}
 
 	// Create the ImagePath directory
@@ -80,7 +131,7 @@ func (i *ImgConfig) PullAll() error {
 	}
 
 	totalBytes := int64(0)
-	processedLayers := make(map[string]bool)
+	processedLayers := make(map[string]v1.Layer)
 	for src, img := range imageMap {
 		tag, err := name.NewTag(src, name.WeakValidation)
 		if err != nil {
@@ -99,42 +150,281 @@ func (i *ImgConfig) PullAll() error {
 			}
 
 			// Only calculate this layer size if we haven't already looked at it
-			if !processedLayers[layerDigest.Hex] {
+			if _, ok := processedLayers[layerDigest.Hex]; !ok {
 				size, err := layer.Size()
 				if err != nil {
 					return fmt.Errorf("unable to get size of layer: %w", err)
 				}
 				totalBytes += size
-				processedLayers[layerDigest.Hex] = true
+				processedLayers[layerDigest.Hex] = layer
 			}
 		}
 	}
 
 	spinner.Updatef("Preparing image sources and cache for image pulling")
+
+	type digestAndTag struct {
+		digest string
+		tag    string
+	}
+
+	// Open the existing OCI layout at ImagesPath as a crane layout.Path
+	cranePath, err := layout.FromPath(i.ImagesPath)
+	if err != nil {
+		// The layout doesn't exist yet, so create an empty one
+		cranePath, err = layout.Write(i.ImagesPath, empty.Index)
+		if err != nil {
+			return err
+		}
+	}
+
+	for tag, img := range tagToImage {
+		imgDigest, err := img.Digest()
+		if err != nil {
+			return fmt.Errorf("unable to get digest for image %s: %w", tag, err)
+		}
+		tagToDigest[tag.String()] = imgDigest.String()
+	}
+
 	spinner.Success()
 
 	// Create a thread to update a progress bar as we save the image files to disk
 	doneSaving := make(chan int)
-	var wg sync.WaitGroup
-	wg.Add(1)
-	go utils.RenderProgressBarForLocalDirWrite(i.ImagesPath, totalBytes, &wg, doneSaving, fmt.Sprintf("Pulling %d images", imgCount))
+	var progressBarWaitGroup sync.WaitGroup
+	progressBarWaitGroup.Add(1)
+	go utils.RenderProgressBarForLocalDirWrite(i.ImagesPath, totalBytes, &progressBarWaitGroup, doneSaving, fmt.Sprintf("Pulling %d images", imgCount))
+
+	// Spawn a goroutine for each layer to write it to disk using crane
+
+	layerWritingConcurrency := utils.NewConcurrencyTools[bool, error](len(processedLayers))
+
+	defer layerWritingConcurrency.Cancel()
+
+	for _, layer := range processedLayers {
+		layer := layer
+		// Function is a combination of https://github.com/google/go-containerregistry/blob/v0.15.2/pkg/v1/layout/write.go#L270-L305
+		// and https://github.com/google/go-containerregistry/blob/v0.15.2/pkg/v1/layout/write.go#L198-L262
+		// with modifications. This allows us to dedupe layers for all images and write them concurrently.
+		go func() {
+			defer layerWritingConcurrency.WaitGroupDone()
+			digest, err := layer.Digest()
+			if errors.Is(err, stream.ErrNotComputed) {
+				// Allow digest errors, since streams may not have calculated the hash
+				// yet. Instead, use an empty value, which will be transformed into a
+				// random file name with `os.CreateTemp` and the final digest will be
+				// calculated after writing to a temp file and before renaming to the
+				// final path.
+				digest = v1.Hash{Algorithm: "sha256", Hex: ""}
+			} else if err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			size, err := layer.Size()
+			if errors.Is(err, stream.ErrNotComputed) {
+				// Allow size errors, since streams may not have calculated the size
+				// yet. Instead, use -1 as a sentinel value meaning that no size
+				// comparison can be done and any sized blob file should be considered
+				// valid and not overwritten.
+				//
+				// TODO: Provide an option to always overwrite blobs.
+				size = -1
+			} else if err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			if layerWritingConcurrency.IsDone() {
+				return
+			}
+
+			readCloser, err := layer.Compressed()
+			if err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			// Create the directory for the blob if it doesn't exist
+			dir := filepath.Join(string(cranePath), "blobs", digest.Algorithm)
+			if err := utils.CreateDirectory(dir, os.ModePerm); err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			if layerWritingConcurrency.IsDone() {
+				return
+			}
+
+			// Check if the blob already exists and is the correct size
+			file := filepath.Join(dir, digest.Hex)
+			if s, err := os.Stat(file); err == nil && !s.IsDir() && (s.Size() == size || size == -1) {
+				layerWritingConcurrency.ProgressChan <- true
+				return
+			}
+
+			if layerWritingConcurrency.IsDone() {
+				return
+			}
+
+			// Write to a temporary file
+			w, err := os.CreateTemp(dir, digest.Hex)
+			if err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+			// Delete the temp file if an error is encountered before renaming
+			defer func() {
+				if err := os.Remove(w.Name()); err != nil && !errors.Is(err, os.ErrNotExist) {
+					message.Warnf("error removing temporary file after encountering an error while writing blob: %v", err)
+				}
+			}()
+
+			defer w.Close()
+
+			if layerWritingConcurrency.IsDone() {
+				return
+			}
+
+			// Copy the layer contents to the temporary file before renaming
+			if n, err := io.Copy(w, readCloser); err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			} else if size != -1 && n != size {
+				layerWritingConcurrency.ErrorChan <- fmt.Errorf("expected blob size %d, but only wrote %d", size, n)
+				return
+			}
+
+			if layerWritingConcurrency.IsDone() {
+				return
+			}
+
+			// Always close reader before renaming, since Close computes the digest in
+			// the case of streaming layers. If Close is not called explicitly, it will
+			// occur in a goroutine that is not guaranteed to succeed before renamer is
+			// called. When renamer is the layer's Digest method, it can return
+			// ErrNotComputed.
+			if err := readCloser.Close(); err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			// Always close the file before renaming
+			if err := w.Close(); err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			// Rename the file based on the final hash
+			renamePath := filepath.Join(string(cranePath), "blobs", digest.Algorithm, digest.Hex)
+			if err := os.Rename(w.Name(), renamePath); err != nil {
+				layerWritingConcurrency.ErrorChan <- err
+				return
+			}
+
+			if layerWritingConcurrency.IsDone() {
+				return
+			}
+
+			layerWritingConcurrency.ProgressChan <- true
+		}()
+	}
+
+	onLayerWritingError := func(err error) error {
+		// Send a signal to the progress bar that we're done and wait for the thread to finish
+		doneSaving <- 1
+		progressBarWaitGroup.Wait()
+		message.WarnErr(err, "Failed to write image layers, trying again up to 3 times...")
+		if strings.HasPrefix(err.Error(), "expected blob size") {
+			message.Warnf("Potential image cache corruption: %s - try clearing cache with \"zarf tools clear-cache\"", err.Error())
+		}
+		return err
+	}
+
+	if err := layerWritingConcurrency.WaitWithoutProgress(onLayerWritingError); err != nil {
+		return err
+	}
+
+	imageSavingConcurrency := utils.NewConcurrencyTools[digestAndTag, error](len(tagToImage))
+
+	defer imageSavingConcurrency.Cancel()
+
+	// Spawn a goroutine for each image to write its config and manifest to disk using crane
+	// All layers should already be in place so this should be extremely fast
 	for tag, img := range tagToImage {
-		// Save the image
-		err := crane.SaveOCI(img, i.ImagesPath)
-		if err != nil {
-			// Check if the cache has been invalidated, and warn the user if so
-			if strings.HasPrefix(err.Error(), "error writing layer: expected blob size") {
-				message.Warnf("Potential image cache corruption: %s - try clearing cache with \"zarf tools clear-cache\"", err.Error())
+		// Rebind tag and img so that each goroutine captures its own copies
+		tag, img := tag, img
+		go func() {
+			// Make sure to call Done() on the WaitGroup when the goroutine finishes
+			defer imageSavingConcurrency.WaitGroupDone()
+
+			// Save the image via crane
+			err := cranePath.WriteImage(img)
+
+			if imageSavingConcurrency.IsDone() {
+				return
 			}
-			return fmt.Errorf("error when trying to save the img (%s): %w", tag.Name(), err)
+
+			if err != nil {
+				// Check if the cache has been invalidated, and warn the user if so
+				if strings.HasPrefix(err.Error(), "error writing layer: expected blob size") {
+					message.Warnf("Potential image cache corruption: %s - try clearing cache with \"zarf tools clear-cache\"", err.Error())
+				}
+				imageSavingConcurrency.ErrorChan <- fmt.Errorf("error when trying to save the img (%s): %w", tag.Name(), err)
+				return
+			}
+
+			if imageSavingConcurrency.IsDone() {
+				return
+			}
+
+			// Get the image digest so we can set an annotation in the image.json later
+			imgDigest, err := img.Digest()
+			if err != nil {
+				imageSavingConcurrency.ErrorChan <- err
+				return
+			}
+
+			if imageSavingConcurrency.IsDone() {
+				return
+			}
+
+			imageSavingConcurrency.ProgressChan <- digestAndTag{digest: imgDigest.String(), tag: tag.String()}
+		}()
+	}
+
+	onImageSavingProgress := func(finishedImage digestAndTag, iteration int) {
+		tagToDigest[finishedImage.tag] = finishedImage.digest
+	}
+
+	onImageSavingError := func(err error) error {
+		// Send a signal to the progress bar that we're done and wait for the thread to finish
+		doneSaving <- 1
+		progressBarWaitGroup.Wait()
+		message.WarnErr(err, "Failed to write image config or manifest, trying again up to 3 times...")
+		return err
+	}
+
+	if err := imageSavingConcurrency.WaitWithProgress(onImageSavingProgress, onImageSavingError); err != nil {
+		return err
+	}
+
+	// For each image, sequentially append its OCI descriptor to the index
+	for tag, img := range tagToImage {
+		desc, err := partial.Descriptor(img)
+		if err != nil {
+			return err
+		}
+
+		if err := cranePath.AppendDescriptor(*desc); err != nil {
+			return err
 		}
-		// Get the image digest so we can set an annotation in the image.json later
 		imgDigest, err := img.Digest()
 		if err != nil {
 			return err
 		}
+
 		tagToDigest[tag.String()] = imgDigest.String()
 	}
@@ -142,9 +432,9 @@ func (i *ImgConfig) PullAll() error {
 		return fmt.Errorf("unable to format OCI layout: %w", err)
 	}
 
-	// Send a signal to the progress bar that we're done and ait for the thread to finish
+	// Send a signal to the progress bar that we're done and wait for the thread to finish
 	doneSaving <- 1
-	wg.Wait()
+	progressBarWaitGroup.Wait()
 
 	return err
 }
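
Reviewer note: the layer-writing goroutines above are, at bottom, the standard write-to-a-temp-file-then-rename pattern from go-containerregistry's layout writer, plus digest-keyed dedupe and concurrency. The standalone sketch below distills just that pattern; it is illustrative only (writeBlob and the sample main are hypothetical, the digest is assumed to be precomputed, and the streaming-layer ErrNotComputed handling is omitted):

package main

import (
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
)

// writeBlob stores the contents of r as <dir>/<hex>, writing to a temporary
// file first and renaming only after a complete, size-checked write.
// size == -1 means "size unknown; skip the size comparison", mirroring the
// sentinel used in the diff above.
func writeBlob(dir, hex string, r io.Reader, size int64) error {
	final := filepath.Join(dir, hex)

	// Dedupe: if a blob of the expected size is already present, keep it.
	if s, err := os.Stat(final); err == nil && !s.IsDir() && (size == -1 || s.Size() == size) {
		return nil
	}

	w, err := os.CreateTemp(dir, hex)
	if err != nil {
		return err
	}
	// Clean up the temp file on any early return; after a successful rename
	// it no longer exists and this remove is a harmless no-op.
	defer os.Remove(w.Name())
	defer w.Close()

	if n, err := io.Copy(w, r); err != nil {
		return err
	} else if size != -1 && n != size {
		return fmt.Errorf("expected blob size %d, but only wrote %d", size, n)
	}

	// Close before renaming so all bytes are flushed to disk.
	if err := w.Close(); err != nil {
		return err
	}

	// Renaming only after a complete write keeps partially written blobs
	// out of the layout, even when goroutines race on the same digest.
	return os.Rename(w.Name(), final)
}

func main() {
	dir, err := os.MkdirTemp("", "blobs-")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(dir)

	payload := "example layer contents"
	if err := writeBlob(dir, "0123abcd", strings.NewReader(payload), int64(len(payload))); err != nil {
		panic(err)
	}
	fmt.Println("blob written to", filepath.Join(dir, "0123abcd"))
}
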
diff --git a/src/pkg/utils/concurrency.go b/src/pkg/utils/concurrency.go
new file mode 100644
index 0000000000..7fc80e7004
--- /dev/null
+++ b/src/pkg/utils/concurrency.go
@@ -0,0 +1,100 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: 2021-Present The Zarf Authors
+
+// Package utils provides generic helper functions.
+package utils
+
+import (
+	"context"
+	"sync"
+)
+
+// ConcurrencyTools is a struct that contains channels and a context for use in concurrent routines
+type ConcurrencyTools[P any, E any] struct {
+	ProgressChan chan P
+	ErrorChan    chan E
+	context      context.Context
+	Cancel       context.CancelFunc
+	waitGroup    *sync.WaitGroup
+	routineCount int
+}
+
+// NewConcurrencyTools creates a new ConcurrencyTools struct
+//
+// length is the number of routines that will be run concurrently
+func NewConcurrencyTools[P any, E any](length int) *ConcurrencyTools[P, E] {
+	ctx, cancel := context.WithCancel(context.TODO())
+
+	progressChan := make(chan P, length)
+
+	errorChan := make(chan E, length)
+
+	waitGroup := sync.WaitGroup{}
+
+	waitGroup.Add(length)
+
+	concurrencyTools := ConcurrencyTools[P, E]{
+		ProgressChan: progressChan,
+		ErrorChan:    errorChan,
+		context:      ctx,
+		Cancel:       cancel,
+		waitGroup:    &waitGroup,
+		routineCount: length,
+	}
+
+	return &concurrencyTools
+}
+
+// IsDone returns true if the context is done.
+func (ct *ConcurrencyTools[P, E]) IsDone() bool {
+	ctx := ct.context
+	select {
+	case <-ctx.Done():
+		return true
+	default:
+		return false
+	}
+}
+
+// WaitGroupDone decrements the internal WaitGroup counter by one.
+func (ct *ConcurrencyTools[P, E]) WaitGroupDone() {
+	ct.waitGroup.Done()
+}
+
+// WaitWithProgress waits for all routines to finish
+//
+// onProgress is a callback function that is called when a routine sends a progress update
+//
+// onError is a callback function that is called when a routine sends an error
+func (ct *ConcurrencyTools[P, E]) WaitWithProgress(onProgress func(P, int), onError func(E) error) error {
+	for i := 0; i < ct.routineCount; i++ {
+		select {
+		case err := <-ct.ErrorChan:
+			ct.Cancel()
+			errResult := onError(err)
+			return errResult
+		case progress := <-ct.ProgressChan:
+			onProgress(progress, i)
+		}
+	}
+	ct.waitGroup.Wait()
+	return nil
+}
+
+// WaitWithoutProgress waits for all routines to finish without a progress callback
+//
+// onError is a callback function that is called when a routine sends an error
+func (ct *ConcurrencyTools[P, E]) WaitWithoutProgress(onError func(E) error) error {
+	for i := 0; i < ct.routineCount; i++ {
+		select {
+		case err := <-ct.ErrorChan:
+			ct.Cancel()
+			errResult := onError(err)
+			return errResult
+		case <-ct.ProgressChan:
+		}
+	}
+	ct.waitGroup.Wait()
+	return nil
+}
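
Reviewer note: a minimal usage sketch of the new ConcurrencyTools helper, mirroring how PullAll drives it (one goroutine per work item, IsDone checks before each channel send, results on ProgressChan, failures on ErrorChan). The squaring task is a made-up stand-in, and the import path is assumed from the module layout:

package main

import (
	"fmt"

	"github.com/defenseunicorns/zarf/src/pkg/utils"
)

func main() {
	items := []int{1, 2, 3, 4, 5}

	// One routine per item; both channels are buffered to len(items),
	// so sends never block even if the consumer has already returned.
	ct := utils.NewConcurrencyTools[int, error](len(items))
	defer ct.Cancel()

	for _, item := range items {
		item := item // capture this iteration's value
		go func() {
			defer ct.WaitGroupDone()

			// Bail out early if another routine has already failed.
			if ct.IsDone() {
				return
			}

			if item < 0 {
				ct.ErrorChan <- fmt.Errorf("item %d is negative", item)
				return
			}

			ct.ProgressChan <- item * item
		}()
	}

	onProgress := func(square int, iteration int) {
		fmt.Printf("(%d of %d) got %d\n", iteration+1, len(items), square)
	}
	onError := func(err error) error {
		return fmt.Errorf("squaring failed: %w", err)
	}

	if err := ct.WaitWithProgress(onProgress, onError); err != nil {
		fmt.Println(err)
	}
}

Because WaitWithProgress cancels the shared context as soon as it receives the first error, the IsDone check before each send is what lets the remaining goroutines exit early instead of doing wasted work.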