From 2d76af77e2f430de0809e8cd57b8ed9c8176f4a7 Mon Sep 17 00:00:00 2001 From: ianhundere <138915+ianhundere@users.noreply.github.com> Date: Tue, 17 Dec 2024 20:44:57 -0500 Subject: [PATCH] feat(certmaker): adds sigstore library for kms signatures, hashivault support, and adds testify back. Signed-off-by: ianhundere <138915+ianhundere@users.noreply.github.com> --- .../certificate_maker_test.go | 50 +-- go.mod | 5 +- go.sum | 6 - pkg/certmaker/certmaker.go | 210 +++++++--- pkg/certmaker/certmaker_test.go | 392 ++++++------------ pkg/certmaker/template_test.go | 51 +-- 6 files changed, 313 insertions(+), 401 deletions(-) diff --git a/cmd/certificate_maker/certificate_maker_test.go b/cmd/certificate_maker/certificate_maker_test.go index 89f9a05f7..caae5d334 100644 --- a/cmd/certificate_maker/certificate_maker_test.go +++ b/cmd/certificate_maker/certificate_maker_test.go @@ -18,10 +18,11 @@ package main import ( "os" "path/filepath" - "strings" "testing" "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetConfigValue(t *testing.T) { @@ -83,25 +84,19 @@ func TestGetConfigValue(t *testing.T) { defer os.Unsetenv(tt.envVar) } got := getConfigValue(tt.flagValue, tt.envVar) - if got != tt.want { - t.Errorf("got %v, want %v", got, tt.want) - } + assert.Equal(t, tt.want, got) }) } } func TestInitLogger(t *testing.T) { logger := initLogger() - if logger == nil { - t.Error("logger should not be nil") - } + require.NotNil(t, logger) } func TestRunCreate(t *testing.T) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) defer os.RemoveAll(tmpDir) rootTemplate := `{ @@ -136,13 +131,9 @@ func TestRunCreate(t *testing.T) { rootTmplPath := filepath.Join(tmpDir, "root-template.json") leafTmplPath := filepath.Join(tmpDir, "leaf-template.json") err = os.WriteFile(rootTmplPath, []byte(rootTemplate), 0600) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) err = os.WriteFile(leafTmplPath, []byte(leafTemplate), 0600) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) tests := []struct { name string @@ -273,13 +264,10 @@ func TestRunCreate(t *testing.T) { err := cmd.Execute() if tt.wantError { - if !strings.Contains(err.Error(), tt.errMsg) { - t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -299,9 +287,7 @@ func TestCreateCommand(t *testing.T) { cmd.Flags().StringVar(&leafKeyID, "leaf-key-id", "", "Leaf key ID") err := cmd.Execute() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) err = cmd.ParseFlags([]string{ "--kms-type", "awskms", @@ -309,12 +295,10 @@ func TestCreateCommand(t *testing.T) { "--root-key-id", "arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab", "--leaf-key-id", "arn:aws:kms:us-west-2:123456789012:key/9876fedc-ba98-7654-3210-fedcba987654", }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) if kmsType != "awskms" { - t.Errorf("got kmsType %v, want awskms", kmsType) + assert.Equal(t, "awskms", kmsType) } if kmsRegion != "us-west-2" { t.Errorf("got kmsRegion %v, want us-west-2", kmsRegion) @@ -330,13 +314,9 @@ func TestCreateCommand(t *testing.T) { func TestRootCommand(t *testing.T) { rootCmd.SetArgs([]string{"--help"}) err := rootCmd.Execute() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) rootCmd.SetArgs([]string{"unknown"}) err = rootCmd.Execute() - if err == nil { - t.Error("expected error for unknown command, got nil") - } + require.Error(t, err) } diff --git a/go.mod b/go.mod index 9954cb996..620d93f4f 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.19.0 github.com/spiffe/go-spiffe/v2 v2.4.0 + github.com/stretchr/testify v1.10.0 github.com/tink-crypto/tink-go-awskms/v2 v2.1.0 github.com/tink-crypto/tink-go-gcpkms/v2 v2.2.0 github.com/tink-crypto/tink-go/v2 v2.2.0 @@ -59,8 +60,6 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.3.1 // indirect @@ -88,6 +87,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chainguard-dev/clog v1.5.1 // indirect github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-jose/go-jose/v3 v3.0.3 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -126,6 +126,7 @@ require ( github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/go.sum b/go.sum index 3f6a4843a..b5d18c662 100644 --- a/go.sum +++ b/go.sum @@ -31,10 +31,6 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.0 h1:+m0M/LFxN43KvUL github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.0/go.mod h1:PwOyop78lveYMRs6oCxjiVyBdyCgIYH6XHIVZO9/SFQ= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= -github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 h1:m/sWOGCREuSBqg2htVQTBY8nOZpyajYztF0vUvSZTuM= -github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0/go.mod h1:Pu5Zksi2KrU7LPbZbNINx6fuVrUp/ffvpxdDj+i8LeE= -github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw= -github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1/go.mod h1:9V2j0jn9jDEkCkv8w/bKTNppX/d0FVA1ud77xCIP4KA= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0 h1:7rKG7UmnrxX4N53TFhkYqjc+kVUZuw0fL8I3Fh+Ld9E= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.0/go.mod h1:Wjo+24QJVhhl/L7jy6w9yzFF2yDOf3cKECAa8ecf9vE= github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.0 h1:eXnN9kaS8TiDwXjoie3hMRLuwdUBUMW9KRgOqB3mCaw= @@ -408,8 +404,6 @@ go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= -go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= diff --git a/pkg/certmaker/certmaker.go b/pkg/certmaker/certmaker.go index cf5e03363..cd342b774 100644 --- a/pkg/certmaker/certmaker.go +++ b/pkg/certmaker/certmaker.go @@ -14,26 +14,49 @@ // // Package certmaker implements a certificate creation utility for Fulcio. -// It supports creating root, intermediate, and leaf certs using (AWS, GCP, Azure). +// It supports creating root, intermediate, and leaf certs using (AWS, GCP, Azure, HashiVault). package certmaker import ( + "bytes" "context" "crypto" "crypto/x509" "encoding/json" "encoding/pem" "fmt" + "io" "os" "strings" - "go.step.sm/crypto/kms/apiv1" - "go.step.sm/crypto/kms/awskms" - "go.step.sm/crypto/kms/azurekms" - "go.step.sm/crypto/kms/cloudkms" + "github.com/sigstore/sigstore/pkg/signature" + "github.com/sigstore/sigstore/pkg/signature/kms" + "github.com/sigstore/sigstore/pkg/signature/options" + + // Initialize AWS KMS provider + _ "github.com/sigstore/sigstore/pkg/signature/kms/aws" + // Initialize Azure KMS provider + _ "github.com/sigstore/sigstore/pkg/signature/kms/azure" + // Initialize GCP KMS provider + _ "github.com/sigstore/sigstore/pkg/signature/kms/gcp" + // Initialize HashiVault KMS provider + _ "github.com/sigstore/sigstore/pkg/signature/kms/hashivault" "go.step.sm/crypto/x509util" ) +type signerWrapper struct { + signature.SignerVerifier +} + +func (s signerWrapper) Public() crypto.PublicKey { + key, _ := s.PublicKey() + return key +} + +func (s signerWrapper) Sign(_ io.Reader, digest []byte, _ crypto.SignerOpts) ([]byte, error) { + return s.SignMessage(bytes.NewReader(digest), options.WithDigest(digest)) +} + // KMSConfig holds config for KMS providers. type KMSConfig struct { Type string @@ -45,15 +68,10 @@ type KMSConfig struct { } // InitKMS initializes KMS provider based on the given config, KMSConfig. -// Supports AWS KMS, Google Cloud KMS, and Azure Key Vault. -func InitKMS(ctx context.Context, config KMSConfig) (apiv1.KeyManager, error) { +func InitKMS(ctx context.Context, config KMSConfig) (signature.SignerVerifier, error) { if err := ValidateKMSConfig(config); err != nil { return nil, fmt.Errorf("invalid KMS configuration: %w", err) } - opts := apiv1.Options{ - Type: apiv1.Type(config.Type), - URI: "", - } // Falls back to LeafKeyID if root is not set keyID := config.RootKeyID @@ -61,42 +79,83 @@ func InitKMS(ctx context.Context, config KMSConfig) (apiv1.KeyManager, error) { keyID = config.LeafKeyID } + var sv signature.SignerVerifier + var err error + switch config.Type { case "awskms": - opts.URI = fmt.Sprintf("awskms:///%s?region=%s", keyID, config.Region) - return awskms.New(ctx, opts) - case "gcpkms": - opts.Type = apiv1.Type("cloudkms") - opts.URI = fmt.Sprintf("cloudkms:%s", keyID) - if credFile, ok := config.Options["credentials-file"]; ok { - if _, err := os.Stat(credFile); err != nil { - if os.IsNotExist(err) { - return nil, fmt.Errorf("credentials file not found: %s", credFile) - } - return nil, fmt.Errorf("error accessing credentials file: %w", err) - } - opts.URI += fmt.Sprintf("?credentials-file=%s", credFile) + ref := fmt.Sprintf("awskms:///%s", keyID) + if config.Region != "" { + os.Setenv("AWS_REGION", config.Region) + } + sv, err = kms.Get(ctx, ref, crypto.SHA256) + if err != nil { + return nil, fmt.Errorf("failed to initialize AWS KMS: %w", err) } - km, err := cloudkms.New(ctx, opts) + + case "gcpkms": + ref := fmt.Sprintf("gcpkms://%s", keyID) + sv, err = kms.Get(ctx, ref, crypto.SHA256) if err != nil { return nil, fmt.Errorf("failed to initialize GCP KMS: %w", err) } - return km, nil + case "azurekms": - opts.URI = keyID - if config.Options["tenant-id"] != "" { - opts.URI += fmt.Sprintf("?tenant-id=%s", config.Options["tenant-id"]) + keyURI := keyID + if strings.HasPrefix(keyID, "azurekms:name=") { + nameStart := strings.Index(keyID, "name=") + 5 + vaultIndex := strings.Index(keyID, ";vault=") + if vaultIndex != -1 { + keyName := strings.TrimSpace(keyID[nameStart:vaultIndex]) + vaultName := strings.TrimSpace(keyID[vaultIndex+7:]) + keyURI = fmt.Sprintf("azurekms://%s.vault.azure.net/%s", vaultName, keyName) + } + } + if config.Options != nil && config.Options["tenant-id"] != "" { + os.Setenv("AZURE_TENANT_ID", config.Options["tenant-id"]) + os.Setenv("AZURE_ADDITIONALLY_ALLOWED_TENANTS", "*") } - return azurekms.New(ctx, opts) + os.Setenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com/") + + sv, err = kms.Get(ctx, keyURI, crypto.SHA256) + if err != nil { + return nil, fmt.Errorf("failed to initialize Azure KMS: %w", err) + } + + case "hashivault": + keyURI := fmt.Sprintf("hashivault://%s", keyID) + if config.Options != nil { + if token := config.Options["token"]; token != "" { + os.Setenv("VAULT_TOKEN", token) + } + if addr := config.Options["address"]; addr != "" { + os.Setenv("VAULT_ADDR", addr) + } + } + + sv, err = kms.Get(ctx, keyURI, crypto.SHA256) + if err != nil { + return nil, fmt.Errorf("failed to initialize HashiVault KMS: %w", err) + } + default: return nil, fmt.Errorf("unsupported KMS type: %s", config.Type) } + + if err != nil { + return nil, fmt.Errorf("failed to get KMS signer: %w", err) + } + if sv == nil { + return nil, fmt.Errorf("KMS returned nil signer") + } + + return sv, nil } // CreateCertificates creates certificates using the provided KMS and templates. // It creates 3 certificates (root -> intermediate -> leaf) if intermediateKeyID is provided, // otherwise creates just 2 certs (root -> leaf). -func CreateCertificates(km apiv1.KeyManager, config KMSConfig, +func CreateCertificates(sv signature.SignerVerifier, config KMSConfig, rootTemplatePath, leafTemplatePath string, rootCertPath, leafCertPath string, intermediateKeyID, intermediateTemplatePath, intermediateCertPath string) error { @@ -107,14 +166,13 @@ func CreateCertificates(km apiv1.KeyManager, config KMSConfig, return fmt.Errorf("error parsing root template: %w", err) } - rootSigner, err := km.CreateSigner(&apiv1.CreateSignerRequest{ - SigningKey: config.RootKeyID, - }) + // Get public key from signer + rootPubKey, err := sv.PublicKey() if err != nil { - return fmt.Errorf("error creating root signer: %w", err) + return fmt.Errorf("error getting root public key: %w", err) } - rootCert, err := x509util.CreateCertificate(rootTmpl, rootTmpl, rootSigner.Public(), rootSigner) + rootCert, err := x509util.CreateCertificate(rootTmpl, rootTmpl, rootPubKey, signerWrapper{sv}) if err != nil { return fmt.Errorf("error creating root certificate: %w", err) } @@ -133,14 +191,20 @@ func CreateCertificates(km apiv1.KeyManager, config KMSConfig, return fmt.Errorf("error parsing intermediate template: %w", err) } - intermediateSigner, err := km.CreateSigner(&apiv1.CreateSignerRequest{ - SigningKey: intermediateKeyID, - }) + // Initialize new KMS for intermediate key + intermediateConfig := config + intermediateConfig.RootKeyID = intermediateKeyID + intermediateSV, err := InitKMS(context.Background(), intermediateConfig) + if err != nil { + return fmt.Errorf("error initializing intermediate KMS: %w", err) + } + + intermediatePubKey, err := intermediateSV.PublicKey() if err != nil { - return fmt.Errorf("error creating intermediate signer: %w", err) + return fmt.Errorf("error getting intermediate public key: %w", err) } - intermediateCert, err := x509util.CreateCertificate(intermediateTmpl, rootCert, intermediateSigner.Public(), rootSigner) + intermediateCert, err := x509util.CreateCertificate(intermediateTmpl, rootCert, intermediatePubKey, signerWrapper{sv}) if err != nil { return fmt.Errorf("error creating intermediate certificate: %w", err) } @@ -150,10 +214,10 @@ func CreateCertificates(km apiv1.KeyManager, config KMSConfig, } signingCert = intermediateCert - signingKey = intermediateSigner + signingKey = signerWrapper{intermediateSV} } else { signingCert = rootCert - signingKey = rootSigner + signingKey = signerWrapper{sv} } // Create leaf cert @@ -162,14 +226,20 @@ func CreateCertificates(km apiv1.KeyManager, config KMSConfig, return fmt.Errorf("error parsing leaf template: %w", err) } - leafSigner, err := km.CreateSigner(&apiv1.CreateSignerRequest{ - SigningKey: config.LeafKeyID, - }) + // Initialize new KMS for leaf key + leafConfig := config + leafConfig.RootKeyID = config.LeafKeyID + leafSV, err := InitKMS(context.Background(), leafConfig) if err != nil { - return fmt.Errorf("error creating leaf signer: %w", err) + return fmt.Errorf("error initializing leaf KMS: %w", err) } - leafCert, err := x509util.CreateCertificate(leafTmpl, signingCert, leafSigner.Public(), signingKey) + leafPubKey, err := leafSV.PublicKey() + if err != nil { + return fmt.Errorf("error getting leaf public key: %w", err) + } + + leafCert, err := x509util.CreateCertificate(leafTmpl, signingCert, leafPubKey, signingKey) if err != nil { return fmt.Errorf("error creating leaf certificate: %w", err) } @@ -229,7 +299,8 @@ func ValidateKMSConfig(config KMSConfig) error { if keyID == "" { return nil } - if strings.HasPrefix(keyID, "arn:aws:kms:") { + switch { + case strings.HasPrefix(keyID, "arn:aws:kms:"): parts := strings.Split(keyID, ":") if len(parts) < 6 { return fmt.Errorf("invalid AWS KMS ARN format for %s", keyType) @@ -237,11 +308,11 @@ func ValidateKMSConfig(config KMSConfig) error { if parts[3] != config.Region { return fmt.Errorf("region in ARN (%s) does not match configured region (%s)", parts[3], config.Region) } - } else if strings.HasPrefix(keyID, "alias/") { + case strings.HasPrefix(keyID, "alias/"): if strings.TrimPrefix(keyID, "alias/") == "" { return fmt.Errorf("alias name cannot be empty for %s", keyType) } - } else { + default: return fmt.Errorf("awskms %s must start with 'arn:aws:kms:' or 'alias/'", keyType) } return nil @@ -327,6 +398,43 @@ func ValidateKMSConfig(config KMSConfig) error { return err } + case "hashivault": + // HashiVault KMS validation + if config.Options == nil { + return fmt.Errorf("options map is required for HashiVault KMS") + } + if config.Options["address"] == "" { + return fmt.Errorf("address is required for HashiVault KMS") + } + if config.Options["token"] == "" { + return fmt.Errorf("token is required for HashiVault KMS") + } + validateHashiVaultKeyID := func(keyID, keyType string) error { + if keyID == "" { + return nil + } + parts := strings.Split(keyID, "/") + if len(parts) < 3 { + return fmt.Errorf("hashivault %s must be in format: transit/keys/keyname", keyType) + } + if parts[0] != "transit" || parts[1] != "keys" { + return fmt.Errorf("hashivault %s must start with 'transit/keys/'", keyType) + } + if parts[2] == "" { + return fmt.Errorf("key name cannot be empty for %s", keyType) + } + return nil + } + if err := validateHashiVaultKeyID(config.RootKeyID, "RootKeyID"); err != nil { + return err + } + if err := validateHashiVaultKeyID(config.IntermediateKeyID, "IntermediateKeyID"); err != nil { + return err + } + if err := validateHashiVaultKeyID(config.LeafKeyID, "LeafKeyID"); err != nil { + return err + } + default: return fmt.Errorf("unsupported KMS type: %s", config.Type) } diff --git a/pkg/certmaker/certmaker_test.go b/pkg/certmaker/certmaker_test.go index 2150e18a3..6fca442ee 100644 --- a/pkg/certmaker/certmaker_test.go +++ b/pkg/certmaker/certmaker_test.go @@ -33,6 +33,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/apiv1" ) @@ -111,9 +113,7 @@ func (m *mockInvalidKMS) Close() error { func TestParseTemplate(t *testing.T) { tmpFile, err := os.CreateTemp("", "cert-template-*.json") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) defer os.Remove(tmpFile.Name()) templateContent := `{ @@ -139,23 +139,13 @@ func TestParseTemplate(t *testing.T) { }` err = os.WriteFile(tmpFile.Name(), []byte(templateContent), 0600) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) tmpl, err := ParseTemplate(tmpFile.Name(), nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if tmpl.Subject.CommonName != "Test CA" { - t.Errorf("got %v, want Test CA", tmpl.Subject.CommonName) - } - if !tmpl.IsCA { - t.Errorf("got %v, want true", tmpl.IsCA) - } - if tmpl.MaxPathLen != 0 { - t.Errorf("got %v, want 0", tmpl.MaxPathLen) - } + require.NoError(t, err) + assert.Equal(t, "Test CA", tmpl.Subject.CommonName) + assert.True(t, tmpl.IsCA) + assert.Equal(t, 0, tmpl.MaxPathLen) } func TestCreateCertificates(t *testing.T) { @@ -168,9 +158,7 @@ func TestCreateCertificates(t *testing.T) { name: "successful certificate creation", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -181,9 +169,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ @@ -194,9 +180,10 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) + + outDir := filepath.Join(tmpDir, "out") + require.NoError(t, os.MkdirAll(outDir, 0755)) config := KMSConfig{ Type: "awskms", @@ -212,9 +199,7 @@ func TestCreateCertificates(t *testing.T) { name: "invalid template path", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -231,9 +216,7 @@ func TestCreateCertificates(t *testing.T) { name: "invalid KMS configuration", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -244,9 +227,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ @@ -257,9 +238,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "invalid", @@ -275,9 +254,7 @@ func TestCreateCertificates(t *testing.T) { name: "with intermediate certificate", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -288,9 +265,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) intermediateTemplate := filepath.Join(tmpDir, "intermediate.json") err = os.WriteFile(intermediateTemplate, []byte(`{ @@ -300,9 +275,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write intermediate template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ @@ -313,9 +286,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -332,9 +303,7 @@ func TestCreateCertificates(t *testing.T) { name: "invalid intermediate template", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -345,9 +314,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ @@ -358,9 +325,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -378,9 +343,7 @@ func TestCreateCertificates(t *testing.T) { name: "invalid intermediate key", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -391,9 +354,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) intermediateTemplate := filepath.Join(tmpDir, "intermediate.json") err = os.WriteFile(intermediateTemplate, []byte(`{ @@ -403,9 +364,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write intermediate template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ @@ -416,9 +375,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -436,18 +393,14 @@ func TestCreateCertificates(t *testing.T) { name: "error creating root certificate", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ "subject": {}, "issuer": {} }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -464,9 +417,7 @@ func TestCreateCertificates(t *testing.T) { name: "error creating leaf certificate", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -477,18 +428,14 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ "subject": {}, "issuer": {} }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -505,9 +452,7 @@ func TestCreateCertificates(t *testing.T) { name: "error writing certificates", setup: func(t *testing.T) (string, KMSConfig, apiv1.KeyManager) { tmpDir, err := os.MkdirTemp("", "cert-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) rootTemplate := filepath.Join(tmpDir, "root.json") err = os.WriteFile(rootTemplate, []byte(`{ @@ -518,9 +463,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write root template: %v", err) - } + require.NoError(t, err) leafTemplate := filepath.Join(tmpDir, "leaf.json") err = os.WriteFile(leafTemplate, []byte(`{ @@ -531,9 +474,7 @@ func TestCreateCertificates(t *testing.T) { "notBefore": "2024-01-01T00:00:00Z", "notAfter": "2025-01-01T00:00:00Z" }`), 0600) - if err != nil { - t.Fatalf("Failed to write leaf template: %v", err) - } + require.NoError(t, err) config := KMSConfig{ Type: "awskms", @@ -544,9 +485,7 @@ func TestCreateCertificates(t *testing.T) { outDir := filepath.Join(tmpDir, "out") err = os.MkdirAll(outDir, 0444) - if err != nil { - t.Fatalf("Failed to create output directory: %v", err) - } + require.NoError(t, err) return tmpDir, config, newMockKMSProvider() }, @@ -561,9 +500,7 @@ func TestCreateCertificates(t *testing.T) { outDir := filepath.Join(tmpDir, "out") err := os.MkdirAll(outDir, 0755) - if err != nil { - t.Fatalf("Failed to create output directory: %v", err) - } + require.NoError(t, err) err = CreateCertificates(kms, config, filepath.Join(tmpDir, "root.json"), @@ -575,15 +512,15 @@ func TestCreateCertificates(t *testing.T) { filepath.Join(outDir, "intermediate.crt")) if tt.wantError != "" { - if err == nil { - t.Error("Expected error but got none") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("Expected error containing %q, got %q", tt.wantError, err.Error()) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } + require.NoError(t, err) + // Verify certificates were created + rootCertPath := filepath.Join(outDir, "root.crt") + leafCertPath := filepath.Join(outDir, "leaf.crt") + require.FileExists(t, rootCertPath) + require.FileExists(t, leafCertPath) } }) } @@ -591,9 +528,7 @@ func TestCreateCertificates(t *testing.T) { func TestWriteCertificateToFile(t *testing.T) { tmpDir, err := os.MkdirTemp("", "cert-write-test-*") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) defer os.RemoveAll(tmpDir) cert := &x509.Certificate{ @@ -638,26 +573,15 @@ func TestWriteCertificateToFile(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := WriteCertificateToFile(tt.cert, tt.path) if tt.wantError { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.errMsg) { - t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) content, err := os.ReadFile(tt.path) - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) block, _ := pem.Decode(content) - if block == nil { - t.Errorf("failed to decode PEM block") - } - if block.Type != "CERTIFICATE" { - t.Errorf("got %v, want CERTIFICATE", block.Type) - } + require.NotNil(t, block) + assert.Equal(t, "CERTIFICATE", block.Type) } }) } @@ -731,41 +655,31 @@ func verifyDirectChain(t *testing.T, rootPath, leafPath string) { Roots: rootPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning}, }) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) } func loadCertificate(t *testing.T, path string) *x509.Certificate { data, err := os.ReadFile(path) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) block, _ := pem.Decode(data) - if block == nil { - t.Fatalf("failed to decode PEM block") - } + require.NotNil(t, block) cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - t.Fatalf("error parsing certificate: %v", err) - } + require.NoError(t, err) return cert } func TestValidateKMSConfig(t *testing.T) { tests := []struct { - name string - config KMSConfig - wantErr bool - wantErrMsg string + name string + config KMSConfig + wantError string }{ { - name: "empty_config", - config: KMSConfig{}, - wantErr: true, - wantErrMsg: "KMS type cannot be empty", + name: "empty_config", + config: KMSConfig{}, + wantError: "KMS type cannot be empty", }, { name: "missing_key_ids", @@ -773,8 +687,7 @@ func TestValidateKMSConfig(t *testing.T) { Type: "awskms", Region: "us-west-2", }, - wantErr: true, - wantErrMsg: "at least one of RootKeyID or LeafKeyID must be specified", + wantError: "at least one of RootKeyID or LeafKeyID must be specified", }, { name: "aws_kms_missing_region", @@ -782,8 +695,7 @@ func TestValidateKMSConfig(t *testing.T) { Type: "awskms", RootKeyID: "arn:aws:kms:us-west-2:123456789012:key/12345678-1234-1234-1234-1234567890ab", }, - wantErr: true, - wantErrMsg: "region is required for AWS KMS", + wantError: "region is required for AWS KMS", }, { name: "aws_kms_invalid_root_key_format", @@ -792,8 +704,7 @@ func TestValidateKMSConfig(t *testing.T) { Region: "us-west-2", RootKeyID: "invalid-key-format", }, - wantErr: true, - wantErrMsg: "awskms RootKeyID must start with 'arn:aws:kms:' or 'alias/'", + wantError: "awskms RootKeyID must start with 'arn:aws:kms:' or 'alias/'", }, { name: "gcp_kms_invalid_root_key_format", @@ -801,8 +712,7 @@ func TestValidateKMSConfig(t *testing.T) { Type: "gcpkms", RootKeyID: "invalid-key-id", }, - wantErr: true, - wantErrMsg: "must start with 'projects/'", + wantError: "must start with 'projects/'", }, { name: "azure_kms_missing_tenant_id", @@ -813,8 +723,7 @@ func TestValidateKMSConfig(t *testing.T) { "vault": "test-vault", }, }, - wantErr: true, - wantErrMsg: "tenant-id is required for Azure KMS", + wantError: "tenant-id is required for Azure KMS", }, { name: "azure_kms_missing_vault", @@ -825,8 +734,7 @@ func TestValidateKMSConfig(t *testing.T) { "tenant-id": "test-tenant", }, }, - wantErr: true, - wantErrMsg: "azurekms RootKeyID must contain ';vault=' parameter", + wantError: "azurekms RootKeyID must contain ';vault=' parameter", }, { name: "azure_kms_missing_options", @@ -834,8 +742,7 @@ func TestValidateKMSConfig(t *testing.T) { Type: "azurekms", RootKeyID: "azurekms:name=mykey", }, - wantErr: true, - wantErrMsg: "options map is required for Azure KMS", + wantError: "options map is required for Azure KMS", }, { name: "unsupported_kms_type", @@ -843,8 +750,7 @@ func TestValidateKMSConfig(t *testing.T) { Type: "unsupported", RootKeyID: "key-id", }, - wantErr: true, - wantErrMsg: "unsupported KMS type: unsupported", + wantError: "unsupported KMS type: unsupported", }, { name: "aws_kms_invalid_arn_format", @@ -853,8 +759,7 @@ func TestValidateKMSConfig(t *testing.T) { Region: "us-west-2", RootKeyID: "arn:aws:kms:us-west-2:invalid", }, - wantErr: true, - wantErrMsg: "invalid AWS KMS ARN format for RootKeyID", + wantError: "invalid AWS KMS ARN format for RootKeyID", }, { name: "aws_kms_region_mismatch", @@ -863,8 +768,7 @@ func TestValidateKMSConfig(t *testing.T) { Region: "us-west-2", RootKeyID: "arn:aws:kms:us-east-1:123456789012:key/test-key", }, - wantErr: true, - wantErrMsg: "region in ARN (us-east-1) does not match configured region (us-west-2)", + wantError: "region in ARN (us-east-1) does not match configured region (us-west-2)", }, { name: "aws_kms_empty_alias", @@ -873,8 +777,7 @@ func TestValidateKMSConfig(t *testing.T) { Region: "us-west-2", RootKeyID: "alias/", }, - wantErr: true, - wantErrMsg: "alias name cannot be empty for RootKeyID", + wantError: "alias name cannot be empty for RootKeyID", }, { name: "azure_kms_empty_key_name", @@ -885,8 +788,7 @@ func TestValidateKMSConfig(t *testing.T) { "tenant-id": "test-tenant", }, }, - wantErr: true, - wantErrMsg: "key name cannot be empty for RootKeyID", + wantError: "key name cannot be empty for RootKeyID", }, { name: "azure_kms_empty_vault_name", @@ -897,24 +799,18 @@ func TestValidateKMSConfig(t *testing.T) { "tenant-id": "test-tenant", }, }, - wantErr: true, - wantErrMsg: "vault name cannot be empty for RootKeyID", + wantError: "vault name cannot be empty for RootKeyID", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := ValidateKMSConfig(tt.config) - if tt.wantErr { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantErrMsg) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantErrMsg) - } + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1298,15 +1194,10 @@ func TestValidateTemplate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, tt.certType) if tt.wantError != "" { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1357,15 +1248,10 @@ func TestValidateTemplateKeyUsageCombinations(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, tt.certType) if tt.wantError != "" { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1416,15 +1302,10 @@ func TestValidateLeafCertificateKeyUsage(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, "leaf") if tt.wantError { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.errMsg) { - t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1448,9 +1329,7 @@ func TestValidateTemplatePath(t *testing.T) { path: "template.txt", setup: func() string { f, err := os.CreateTemp("", "template.txt") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) return f.Name() }, wantError: "must have .json extension", @@ -1460,13 +1339,9 @@ func TestValidateTemplatePath(t *testing.T) { path: "invalid.json", setup: func() string { f, err := os.CreateTemp("", "template*.json") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) err = os.WriteFile(f.Name(), []byte("invalid json"), 0600) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) return f.Name() }, wantError: "invalid JSON", @@ -1476,13 +1351,9 @@ func TestValidateTemplatePath(t *testing.T) { path: "valid.json", setup: func() string { f, err := os.CreateTemp("", "template*.json") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) err = os.WriteFile(f.Name(), []byte(`{"key": "value"}`), 0600) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) return f.Name() }, }, @@ -1498,15 +1369,10 @@ func TestValidateTemplatePath(t *testing.T) { err := ValidateTemplatePath(path) if tt.wantError != "" { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1548,15 +1414,10 @@ func TestGCPKMSValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateKMSConfig(tt.config) if tt.wantError != "" { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1612,6 +1473,7 @@ func TestAzureKMSValidation(t *testing.T) { config: KMSConfig{ Type: "azurekms", RootKeyID: "azurekms:name=mykey;vault=myvault", + LeafKeyID: "azurekms:name=leaf-key;vault=test-vault", Options: map[string]string{ "tenant-id": "test-tenant", }, @@ -1623,15 +1485,10 @@ func TestAzureKMSValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateKMSConfig(tt.config) if tt.wantError != "" { - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } + require.NoError(t, err) } }) } @@ -1682,10 +1539,11 @@ func TestInitKMSErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := InitKMS(ctx, tt.config) - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) + if tt.wantError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + } else { + require.NoError(t, err) } }) } @@ -1693,15 +1551,11 @@ func TestInitKMSErrors(t *testing.T) { func TestInitKMS(t *testing.T) { tmpDir, err := os.MkdirTemp("", "kms-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } + require.NoError(t, err) defer os.RemoveAll(tmpDir) privKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("Failed to generate private key: %v", err) - } + require.NoError(t, err) privKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", @@ -1721,9 +1575,7 @@ func TestInitKMS(t *testing.T) { "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test@test-project.iam.gserviceaccount.com" }`, string(privKeyPEM))), 0600) - if err != nil { - t.Fatalf("Failed to write credentials file: %v", err) - } + require.NoError(t, err) ctx := context.Background() tests := []struct { @@ -1783,21 +1635,11 @@ func TestInitKMS(t *testing.T) { t.Run(tt.name, func(t *testing.T) { km, err := InitKMS(ctx, tt.config) if tt.wantError { - if err == nil { - t.Error("expected error but got nil") - } else if !strings.Contains(err.Error(), tt.errMsg) { - t.Errorf("error %q should contain %q", err.Error(), tt.errMsg) - } - if km != nil { - t.Error("expected nil KMS but got non-nil") - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) } else { - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if km == nil { - t.Error("expected non-nil KMS but got nil") - } + require.NoError(t, err) + require.NotNil(t, km) } }) } diff --git a/pkg/certmaker/template_test.go b/pkg/certmaker/template_test.go index 055fbbdbf..14d05826d 100644 --- a/pkg/certmaker/template_test.go +++ b/pkg/certmaker/template_test.go @@ -21,6 +21,9 @@ import ( "os" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestValidateTemplateFields(t *testing.T) { @@ -221,9 +224,10 @@ func TestValidateTemplateFields(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := ValidateTemplate(tt.tmpl, tt.parent, tt.certType) if tt.wantError != "" { - if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + } else { + require.NoError(t, err) } }) } @@ -231,7 +235,8 @@ func TestValidateTemplateFields(t *testing.T) { func TestParseTemplateErrors(t *testing.T) { tests := []struct { - name string + name string + content string wantError string }{ @@ -265,22 +270,15 @@ func TestParseTemplateErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tmpFile, err := os.CreateTemp("", "cert-template-*.json") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) defer os.Remove(tmpFile.Name()) err = os.WriteFile(tmpFile.Name(), []byte(tt.content), 0600) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.NoError(t, err) _, err = ParseTemplate(tmpFile.Name(), nil) - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), tt.wantError) { - t.Errorf("error %q should contain %q", err.Error(), tt.wantError) - } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) }) } @@ -305,26 +303,15 @@ func TestInvalidCertificateType(t *testing.T) { } err := ValidateTemplate(tmpl, nil, "invalid") - if err == nil { - t.Errorf("expected error, got nil") - } else if !strings.Contains(err.Error(), "invalid certificate type") { - t.Errorf("error %q should contain %q", err.Error(), "invalid certificate type") - } + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid certificate type") } func TestContainsExtKeyUsage(t *testing.T) { - if containsExtKeyUsage(nil, "CodeSigning") { - t.Error("empty list (nil) should return false") - } - if containsExtKeyUsage([]string{}, "CodeSigning") { - t.Error("empty list should return false") - } - if !containsExtKeyUsage([]string{"CodeSigning"}, "CodeSigning") { - t.Error("should find matching usage") - } - if containsExtKeyUsage([]string{"OtherUsage"}, "CodeSigning") { - t.Error("should not find non-matching usage") - } + assert.False(t, containsExtKeyUsage(nil, "CodeSigning"), "empty list (nil) should return false") + assert.False(t, containsExtKeyUsage([]string{}, "CodeSigning"), "empty list should return false") + assert.True(t, containsExtKeyUsage([]string{"CodeSigning"}, "CodeSigning"), "should find matching usage") + assert.False(t, containsExtKeyUsage([]string{"OtherUsage"}, "CodeSigning"), "should not find non-matching usage") } func containsExtKeyUsage(usages []string, target string) bool {