From d05b99a905ea37320b4598a259789dcc3115d88c Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Wed, 14 Jul 2021 17:43:29 +0000 Subject: [PATCH] feat: add checksum validation `bldr validate --checksums`. Refs #64. Signed-off-by: Alexey Palazhchenko --- cmd/update.go | 147 ++++++++++++--------- cmd/validate.go | 83 +++++++++++- internal/pkg/types/v1alpha2/source.go | 58 +++++++- internal/pkg/types/v1alpha2/source_test.go | 52 ++++++++ 4 files changed, 272 insertions(+), 68 deletions(-) create mode 100644 internal/pkg/types/v1alpha2/source_test.go diff --git a/cmd/update.go b/cmd/update.go index b591a6a..5c7bf5e 100644 --- a/cmd/update.go +++ b/cmd/update.go @@ -30,6 +30,87 @@ type updateInfo struct { *update.LatestInfo } +//nolint:gocyclo +func checkUpdates(ctx context.Context, set solver.PackageSet, l *log.Logger) error { + var ( + wg sync.WaitGroup + concurrency = runtime.GOMAXPROCS(-1) + sources = make(chan *pkgInfo) + updates = make(chan *updateInfo) + ) + + // start updaters + for i := 0; i < concurrency; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for src := range sources { + res, e := update.Latest(ctx, src.source) + if e != nil { + l.Print(e) + continue + } + + updates <- &updateInfo{ + file: src.file, + LatestInfo: res, + } + } + }() + } + + var ( + res []updateInfo + done = make(chan struct{}) + ) + + // start results reader + go func() { + for update := range updates { + res = append(res, *update) + } + + close(done) + }() + + // send work to updaters + for _, node := range set { + for _, step := range node.Pkg.Steps { + for _, src := range step.Sources { + sources <- &pkgInfo{ + file: node.Pkg.FileName, + source: src.URL, + } + } + } + } + + close(sources) + wg.Wait() + close(updates) + <-done + + sort.Slice(res, func(i, j int) bool { return res[i].file < res[j].file }) + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintf(w, "%s\t%s\t%s\n", "File", "Update", "URL") + + for _, info := range res { + if updateCmdFlag.all || info.HasUpdate { + url := info.LatestURL + if url == "" { + url = info.BaseURL + } + + fmt.Fprintf(w, "%s\t%t\t%s\n", info.file, info.HasUpdate, url) + } + } + + return w.Flush() +} + var updateCmdFlag struct { all bool dry bool @@ -59,71 +140,7 @@ var updateCmd = &cobra.Command{ l.SetOutput(ioutil.Discard) } - concurrency := runtime.GOMAXPROCS(-1) - var wg sync.WaitGroup - sources := make(chan *pkgInfo) - updates := make(chan *updateInfo) - for i := 0; i < concurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - for src := range sources { - res, e := update.Latest(context.TODO(), src.source) - if e != nil { - l.Print(e) - continue - } - - updates <- &updateInfo{ - file: src.file, - LatestInfo: res, - } - } - }() - } - - var res []updateInfo - done := make(chan struct{}) - go func() { - for update := range updates { - res = append(res, *update) - } - close(done) - }() - - for _, node := range packages.ToSet() { - for _, step := range node.Pkg.Steps { - for _, src := range step.Sources { - sources <- &pkgInfo{ - file: node.Pkg.FileName, - source: src.URL, - } - } - } - } - close(sources) - wg.Wait() - close(updates) - <-done - - sort.Slice(res, func(i, j int) bool { return res[i].file < res[j].file }) - - w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) - fmt.Fprintf(w, "%s\t%s\t%s\n", "File", "Update", "URL") - - for _, info := range res { - if updateCmdFlag.all || info.HasUpdate { - url := info.LatestURL - if url == "" { - url = info.BaseURL - } - - fmt.Fprintf(w, "%s\t%t\t%s\n", info.file, info.HasUpdate, url) - } - } - - if err = w.Flush(); err != nil { + if err = checkUpdates(context.TODO(), packages.ToSet(), l); err != nil { log.Fatal(err) } }, diff --git a/cmd/validate.go b/cmd/validate.go index 4300a96..5439fec 100644 --- a/cmd/validate.go +++ b/cmd/validate.go @@ -5,12 +5,81 @@ package cmd import ( + "context" + "fmt" + "io/ioutil" "log" + "runtime" + "sync" + "github.com/hashicorp/go-multierror" "github.com/spf13/cobra" + "github.com/talos-systems/bldr/internal/pkg/solver" + "github.com/talos-systems/bldr/internal/pkg/types/v1alpha2" ) +func validateChecksums(ctx context.Context, set solver.PackageSet, l *log.Logger) error { + var ( + wg sync.WaitGroup + concurrency = runtime.GOMAXPROCS(-1) + pkgs = make(chan *v1alpha2.Pkg) + errors = make(chan error) + ) + + // start downloaders + for i := 0; i < concurrency; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for pkg := range pkgs { + for _, step := range pkg.Steps { + for _, src := range step.Sources { + l.Printf("downloading %s ...", src.URL) + + _, _, err := src.ValidateChecksums(ctx) + if err != nil { + errors <- fmt.Errorf("%s: %w", pkg.Name, err) + } + } + } + } + }() + } + + var ( + multiErr *multierror.Error + done = make(chan struct{}) + ) + + // start results reader + go func() { + for err := range errors { + multiErr = multierror.Append(multiErr, err) + } + + close(done) + }() + + // send work to downloaders + for _, node := range set { + pkgs <- node.Pkg + } + + close(pkgs) + wg.Wait() + close(errors) + <-done + + return multiErr.ErrorOrNil() +} + +var validateCmdFlags struct { + checksums bool +} + // validateCmd represents the validate command. var validateCmd = &cobra.Command{ Use: "validate", @@ -23,13 +92,25 @@ loads them and validates for errors. `, Context: options.GetVariables(), } - _, err := solver.NewPackages(&loader) + packages, err := solver.NewPackages(&loader) if err != nil { log.Fatal(err) } + + if validateCmdFlags.checksums { + l := log.New(log.Writer(), "[validate] ", log.Flags()) + if !debug { + l.SetOutput(ioutil.Discard) + } + + if err = validateChecksums(context.TODO(), packages.ToSet(), l); err != nil { + log.Fatal(err) + } + } }, } func init() { + validateCmd.Flags().BoolVar(&validateCmdFlags.checksums, "checksums", true, "validate checksums") rootCmd.AddCommand(validateCmd) } diff --git a/internal/pkg/types/v1alpha2/source.go b/internal/pkg/types/v1alpha2/source.go index b109fb0..c8632ed 100644 --- a/internal/pkg/types/v1alpha2/source.go +++ b/internal/pkg/types/v1alpha2/source.go @@ -5,8 +5,14 @@ package v1alpha2 import ( + "context" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" "errors" "fmt" + "io" + "net/http" "net/url" "github.com/hashicorp/go-multierror" @@ -53,13 +59,61 @@ func (source *Source) Validate() error { multiErr = multierror.Append(multiErr, errors.New("source.destination can't be empty")) } - if source.SHA256 == "" { + switch len(source.SHA256) { + case 0: multiErr = multierror.Append(multiErr, errors.New("source.sha256 can't be empty")) + case 64: //nolint:gomnd + // nothing + default: + multiErr = multierror.Append(multiErr, errors.New("source.sha256 should be 64 chars long")) } - if source.SHA512 == "" { + switch len(source.SHA512) { + case 0: multiErr = multierror.Append(multiErr, errors.New("source.sha512 can't be empty")) + case 128: //nolint:gomnd + // nothing + default: + multiErr = multierror.Append(multiErr, errors.New("source.sha512 should be 128 chars long")) } return multiErr.ErrorOrNil() } + +// ValidateChecksums downloads the source, validates checksums, +// and returns actual checksums and validation error, if any. +func (source *Source) ValidateChecksums(ctx context.Context) (string, string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", source.URL, nil) + if err != nil { + return "", "", err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", err + } + + defer resp.Body.Close() //nolint:errcheck + + s256 := sha256.New() + s512 := sha512.New() + + if _, err = io.Copy(io.MultiWriter(s256, s512), resp.Body); err != nil { + return "", "", err + } + + var ( + actualSHA256, actualSHA512 string + multiErr *multierror.Error + ) + + if actualSHA256 = hex.EncodeToString(s256.Sum(nil)); source.SHA256 != actualSHA256 { + multiErr = multierror.Append(multiErr, fmt.Errorf("source.sha256 does not match: expected %s, got %s", source.SHA256, actualSHA256)) + } + + if actualSHA512 = hex.EncodeToString(s512.Sum(nil)); source.SHA512 != actualSHA512 { + multiErr = multierror.Append(multiErr, fmt.Errorf("source.sha512 does not match: expected %s, got %s", source.SHA512, actualSHA512)) + } + + return actualSHA256, actualSHA512, multiErr.ErrorOrNil() +} diff --git a/internal/pkg/types/v1alpha2/source_test.go b/internal/pkg/types/v1alpha2/source_test.go new file mode 100644 index 0000000..8b58061 --- /dev/null +++ b/internal/pkg/types/v1alpha2/source_test.go @@ -0,0 +1,52 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package v1alpha2_test + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/talos-systems/bldr/internal/pkg/types/v1alpha2" +) + +//nolint:lll +func TestSourceValidateChecksums(t *testing.T) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + + const ( + expectedSHA256 = "2aa5f088cbb332e73fc3def546800616b38d3bfe6b8713b8a6404060f22503e8" + expectedSHA512 = "ce64105ff71615f9d235cc7c8656b6409fc40cc90d15a28d355fadd9072d2eab842af379dd8bba0f1181715753143e4a07491e0f9e5f8df806327d7c95a34fae" + ) + + source := v1alpha2.Source{ + URL: "https://dl.google.com/go/go1.12.5.src.tar.gz", + Destination: "go1.12.5.src.tar.gz", + SHA256: expectedSHA256, + SHA512: expectedSHA512, + } + + actualSHA256, actualSHA512, err := source.ValidateChecksums(context.Background()) + require.NoError(t, err) + assert.Equal(t, expectedSHA256, actualSHA256) + assert.Equal(t, expectedSHA512, actualSHA512) + + source.SHA256 = strings.Repeat("0", 64) + source.SHA512 = strings.Repeat("1", 64) + + actualSHA256, actualSHA512, err = source.ValidateChecksums(context.Background()) + assert.EqualError(t, err, `2 errors occurred: + * source.sha256 does not match: expected 0000000000000000000000000000000000000000000000000000000000000000, got 2aa5f088cbb332e73fc3def546800616b38d3bfe6b8713b8a6404060f22503e8 + * source.sha512 does not match: expected 1111111111111111111111111111111111111111111111111111111111111111, got ce64105ff71615f9d235cc7c8656b6409fc40cc90d15a28d355fadd9072d2eab842af379dd8bba0f1181715753143e4a07491e0f9e5f8df806327d7c95a34fae + +`) + assert.Equal(t, expectedSHA256, actualSHA256) + assert.Equal(t, expectedSHA512, actualSHA512) +}