From 4430c6126938c99adf64a61b1f6946da53a89962 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 28 May 2024 14:21:50 +0000 Subject: [PATCH] removed unecessary comments, added tests, extracted some duplication --- pkg/apiserver/apiserver.go | 11 ++- pkg/apiserver/webhooks.go | 94 +++++++----------------- pkg/apiserver/webhooks_test.go | 126 +++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 74 deletions(-) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index ded3f0b9..05e909d6 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -289,14 +289,11 @@ func (s *PorchServer) Run(ctx context.Context) error { certStorageDir, found := os.LookupEnv("CERT_STORAGE_DIR") if found && strings.TrimSpace(certStorageDir) != "" { - // check if we should use cert manager webhook setup - useCertManWebhook := false - // if we dont find the env var or the var is not true - CertManWebhook, found := os.LookupEnv("CERT_MAN_WEBHOOKS") - if found || CertManWebhook == "true"{ - useCertManWebhook = true + useCertMan := false + if _, found := os.LookupEnv("USE_CERT_MAN_FOR_WEBHOOK"); found { + useCertMan = true } - if err := setupWebhooks(ctx, webhookNs, certStorageDir, useCertManWebhook); err != nil { + if err := setupWebhooks(ctx, webhookNs, certStorageDir, useCertMan); err != nil { klog.Errorf("%v\n", err) return err } diff --git a/pkg/apiserver/webhooks.go b/pkg/apiserver/webhooks.go index c4daa7b5..dc2061e0 100644 --- a/pkg/apiserver/webhooks.go +++ b/pkg/apiserver/webhooks.go @@ -30,7 +30,7 @@ import ( "net/http" "os" "path/filepath" - "sync" + "time" "github.com/fsnotify/fsnotify" @@ -55,7 +55,6 @@ const ( ) var ( - mu sync.Mutex cert tls.Certificate certModTime time.Time ) @@ -69,7 +68,7 @@ func setupWebhooks(ctx context.Context, webhookNs string, certStorageDir string, if err := createValidatingWebhook(ctx, webhookNs, caBytes); err != nil { return err } - + } if err := runWebhookServer(certStorageDir, useCertManWebhook); err != nil { return err @@ -238,19 +237,13 @@ func createValidatingWebhook(ctx context.Context, webhookNs string, caCert []byt return nil } -// load the certificate from the secret and update when secret cert data changes e.g. +// load the certificate & keep note of time loaded for reload on new cert details func loadCertificate(certPath, keyPath string) (tls.Certificate, error) { - // get info about cert manager certificate secret mounted as a volume on the porch server pod certInfo, err := os.Stat(certPath) if err != nil { return tls.Certificate{}, err } - // if the last time this secret was modified was after the last time we loaded its cert files - // we lock access to the mount path and load the keypair from the files in our tls then release the lock - // set this new loaded cert as our current cert and note the modtime for next reload if certInfo.ModTime().After(certModTime) { - mu.Lock() - defer mu.Unlock() newCert, err := tls.LoadX509KeyPair(certPath, keyPath) if err != nil { return tls.Certificate{}, err @@ -269,7 +262,11 @@ func watchCertificates(directory, certFile, keyFile string) { klog.Fatalf("failed to start certificate watcher: %v", err) } defer watcher.Close() - + // start the watcher with the dir to watch + err = watcher.Add(directory) + if err != nil { + klog.Errorf("Error in running watcher: %v", err) + } // if the watcher notices any changes on the mount dir of the secret such as creations or writes to the files in this dir // attempt to load tls cert using the new cert and key files provided and log output done := make(chan bool) @@ -296,72 +293,35 @@ func watchCertificates(directory, certFile, keyFile string) { } } }() - // start the watcher with the dir to watch - err = watcher.Add(directory) - if err != nil { - klog.Fatalf("Error in running watcher: %v", err) - } - <-done } -// return current cert -func getCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { - mu.Lock() - defer mu.Unlock() - return &cert, nil -} func runWebhookServer(certStorageDir string, useCertManWebhook bool) error { certFile := filepath.Join(certStorageDir, "tls.crt") keyFile := filepath.Join(certStorageDir, "tls.key") - // load the cert for the first time - + cert, err := loadCertificate(certFile, keyFile) + if err != nil { + klog.Errorf("failed to load certificate: %v", err) + return err + } if useCertManWebhook { - _, err := loadCertificate(certFile, keyFile) - if err != nil { - klog.Errorf("failed to load certificate: %v", err) - } - // Start watching the certificate files for changes - // watch for changes in directory where secret is mounted go watchCertificates(certStorageDir, certFile, keyFile) - - klog.Infoln("Starting webhook server") - http.HandleFunc(serverEndpoint, validateDeletion) - server := http.Server{ - Addr: fmt.Sprintf(":%d", webhookServicePort), - TLSConfig: &tls.Config{ - GetCertificate: getCertificate, - }, - } - go func() { - err = server.ListenAndServeTLS("", "") - if err != nil { - klog.Errorf("could not start server: %v", err) - } - }() - return err - - } else { - cert, err := tls.LoadX509KeyPair(certFile, keyFile) + } + klog.Infoln("Starting webhook server") + http.HandleFunc(serverEndpoint, validateDeletion) + server := http.Server{ + Addr: fmt.Sprintf(":%d", webhookServicePort), + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + } + go func() { + err = server.ListenAndServeTLS("", "") if err != nil { - return err - } - klog.Infoln("Starting webhook server") - http.HandleFunc(serverEndpoint, validateDeletion) - server := http.Server{ - Addr: fmt.Sprintf(":%d", webhookServicePort), - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - }, + klog.Errorf("could not start server: %v", err) } - go func() { - err = server.ListenAndServeTLS("", "") - if err != nil { - klog.Errorf("could not start server: %v", err) - } - }() - return err - } + }() + return err } func validateDeletion(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/apiserver/webhooks_test.go b/pkg/apiserver/webhooks_test.go index 84aad665..512a2835 100644 --- a/pkg/apiserver/webhooks_test.go +++ b/pkg/apiserver/webhooks_test.go @@ -17,6 +17,7 @@ package apiserver import ( "bytes" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" @@ -24,6 +25,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/stretchr/testify/require" admissionv1 "k8s.io/api/admission/v1" @@ -58,6 +60,130 @@ func TestCreateCerts(t *testing.T) { require.True(t, strings.HasSuffix(keyStr, "\n-----END RSA PRIVATE KEY-----")) } +func TestLoadCertificate(t *testing.T) { + + //what do i need to test. + // first create dummy certs for testing + dir := t.TempDir() + defer func() { + require.NoError(t, os.RemoveAll(dir)) + }() + _, err := createCerts("", dir) + require.NoError(t, err) + + //1. test file that cannot be os.stated or causes that function to fail + _, err1 := loadCertificate(filepath.Join(dir, "nonexistingcrtfile.key"), filepath.Join(dir, "nonexistingkeyfile.key")) + require.Error(t, err1) + // reseting back to 0 + certModTime = time.Time{} + //2. test happy path of os.stat and continue to next error + //3. test loading good cert happy path + keypath := filepath.Join(dir, "tls.key") + crtpath := filepath.Join(dir, "tls.crt") + + _, err2 := loadCertificate(crtpath, keypath) + require.NoError(t, err2) + certModTime = time.Time{} + + //4. test loading faulty cert error + data := []byte("Hello, World!") + + writeErr := os.WriteFile(keypath, data, 0644) + require.NoError(t, writeErr) + writeErr2 := os.WriteFile(crtpath, data, 0644) + require.NoError(t, writeErr2) + + _, err3 := loadCertificate(filepath.Join(dir, "tls.crt"), filepath.Join(dir, "tls.key")) + require.Error(t, err3) + certModTime = time.Time{} +} + +func captureStderr(f func()) string { + r, w, _ := os.Pipe() + stderr := os.Stderr + os.Stderr = w + outC := make(chan string) + + // Copy the output in a separate goroutine so printing can't block indefinitely. + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + // Run the provided function and capture the stderr output + f() + + // Restore the original stderr and close the write-end of the pipe so the goroutine will exit + os.Stderr = stderr + w.Close() + out := <-outC + + return out +} + +func TestWatchCertificates(t *testing.T) { + // method for processing klog output + assertLogMessages := func(log string) error { + if len(log) > 0 { + if log[0] == 'E' || log[0] == 'W' || log[0] == 'F' { + return errors.New("Error Occured in Watcher") + } + } + return nil + } + // Set up the temp directory with dummy certificate files + dir := t.TempDir() + defer func() { + require.NoError(t, os.RemoveAll(dir)) + }() + _, err := createCerts("", dir) + require.NoError(t, err) + + keyFile := filepath.Join(dir, "tls.key") + certFile := filepath.Join(dir, "tls.crt") + + // firstly test error occuring from invalid entity for watcher to watch. aka invalid dir + // we expect an error + go watchCertificates("Dummy Directory that does not exist", certFile, keyFile) + + invalid_watch_entity_logs := captureStderr(func() { + time.Sleep(100 * time.Millisecond) // Give some time for the logs to be flushed + }) + t.Log(invalid_watch_entity_logs) + err = assertLogMessages(invalid_watch_entity_logs) + require.Error(t, err) + + go watchCertificates(dir, certFile, keyFile) + time.Sleep(1 * time.Second) + + //create file to trigger change but not alter the certificate contents + //should trigger reload and certificate reloaded successfully + newFilePath := filepath.Join(dir, "new_temp_file.txt") + _, err = os.Create(newFilePath) + require.NoError(t, err) + + valid_reload_logs := captureStderr(func() { + time.Sleep(100 * time.Millisecond) // Give some time for the logs to be flushed + }) + t.Log(valid_reload_logs) + err = assertLogMessages(valid_reload_logs) + require.NoError(t, err) + + // Modify the certificate file to trigger a file system event + // should cause an error log since cert contents are not valid anymore + certModTime = time.Time{} + err = os.WriteFile(certFile, []byte("dummy text"), 0660) + require.NoError(t, err) + + invalid_reload_logs := captureStderr(func() { + time.Sleep(100 * time.Millisecond) // Give some time for the logs to be flushed + }) + t.Log(invalid_reload_logs) + err = assertLogMessages(invalid_reload_logs) + require.Error(t, err) +} + func TestValidateDeletion(t *testing.T) { t.Run("invalid content-type", func(t *testing.T) { request, err := http.NewRequest(http.MethodPost, serverEndpoint, nil)