Skip to content

Commit

Permalink
removed unecessary comments, added tests, extracted some duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed May 28, 2024
1 parent 98efbce commit 4430c61
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 74 deletions.
11 changes: 4 additions & 7 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,11 @@ func (s *PorchServer) Run(ctx context.Context) error {

certStorageDir, found := os.LookupEnv("CERT_STORAGE_DIR")
if found && strings.TrimSpace(certStorageDir) != "" {
// check if we should use cert manager webhook setup
useCertManWebhook := false
// if we dont find the env var or the var is not true
CertManWebhook, found := os.LookupEnv("CERT_MAN_WEBHOOKS")
if found || CertManWebhook == "true"{
useCertManWebhook = true
useCertMan := false
if _, found := os.LookupEnv("USE_CERT_MAN_FOR_WEBHOOK"); found {
useCertMan = true
}
if err := setupWebhooks(ctx, webhookNs, certStorageDir, useCertManWebhook); err != nil {
if err := setupWebhooks(ctx, webhookNs, certStorageDir, useCertMan); err != nil {
klog.Errorf("%v\n", err)
return err
}
Expand Down
94 changes: 27 additions & 67 deletions pkg/apiserver/webhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"net/http"
"os"
"path/filepath"
"sync"

"time"

"github.com/fsnotify/fsnotify"
Expand All @@ -55,7 +55,6 @@ const (
)

var (
mu sync.Mutex
cert tls.Certificate
certModTime time.Time
)
Expand All @@ -69,7 +68,7 @@ func setupWebhooks(ctx context.Context, webhookNs string, certStorageDir string,
if err := createValidatingWebhook(ctx, webhookNs, caBytes); err != nil {
return err
}

}
if err := runWebhookServer(certStorageDir, useCertManWebhook); err != nil {
return err
Expand Down Expand Up @@ -238,19 +237,13 @@ func createValidatingWebhook(ctx context.Context, webhookNs string, caCert []byt
return nil
}

// load the certificate from the secret and update when secret cert data changes e.g.
// load the certificate & keep note of time loaded for reload on new cert details
func loadCertificate(certPath, keyPath string) (tls.Certificate, error) {
// get info about cert manager certificate secret mounted as a volume on the porch server pod
certInfo, err := os.Stat(certPath)
if err != nil {
return tls.Certificate{}, err
}
// if the last time this secret was modified was after the last time we loaded its cert files
// we lock access to the mount path and load the keypair from the files in our tls then release the lock
// set this new loaded cert as our current cert and note the modtime for next reload
if certInfo.ModTime().After(certModTime) {
mu.Lock()
defer mu.Unlock()
newCert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return tls.Certificate{}, err
Expand All @@ -269,7 +262,11 @@ func watchCertificates(directory, certFile, keyFile string) {
klog.Fatalf("failed to start certificate watcher: %v", err)
}
defer watcher.Close()

// start the watcher with the dir to watch
err = watcher.Add(directory)
if err != nil {
klog.Errorf("Error in running watcher: %v", err)
}
// if the watcher notices any changes on the mount dir of the secret such as creations or writes to the files in this dir
// attempt to load tls cert using the new cert and key files provided and log output
done := make(chan bool)
Expand All @@ -296,72 +293,35 @@ func watchCertificates(directory, certFile, keyFile string) {
}
}
}()
// start the watcher with the dir to watch
err = watcher.Add(directory)
if err != nil {
klog.Fatalf("Error in running watcher: %v", err)
}

<-done
}
// return current cert
func getCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
mu.Lock()
defer mu.Unlock()
return &cert, nil
}

func runWebhookServer(certStorageDir string, useCertManWebhook bool) error {
certFile := filepath.Join(certStorageDir, "tls.crt")
keyFile := filepath.Join(certStorageDir, "tls.key")
// load the cert for the first time

cert, err := loadCertificate(certFile, keyFile)
if err != nil {
klog.Errorf("failed to load certificate: %v", err)
return err
}
if useCertManWebhook {
_, err := loadCertificate(certFile, keyFile)
if err != nil {
klog.Errorf("failed to load certificate: %v", err)
}
// Start watching the certificate files for changes
// watch for changes in directory where secret is mounted
go watchCertificates(certStorageDir, certFile, keyFile)

klog.Infoln("Starting webhook server")
http.HandleFunc(serverEndpoint, validateDeletion)
server := http.Server{
Addr: fmt.Sprintf(":%d", webhookServicePort),
TLSConfig: &tls.Config{
GetCertificate: getCertificate,
},
}
go func() {
err = server.ListenAndServeTLS("", "")
if err != nil {
klog.Errorf("could not start server: %v", err)
}
}()
return err

} else {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
}
klog.Infoln("Starting webhook server")
http.HandleFunc(serverEndpoint, validateDeletion)
server := http.Server{
Addr: fmt.Sprintf(":%d", webhookServicePort),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
}
go func() {
err = server.ListenAndServeTLS("", "")
if err != nil {
return err
}
klog.Infoln("Starting webhook server")
http.HandleFunc(serverEndpoint, validateDeletion)
server := http.Server{
Addr: fmt.Sprintf(":%d", webhookServicePort),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
klog.Errorf("could not start server: %v", err)
}
go func() {
err = server.ListenAndServeTLS("", "")
if err != nil {
klog.Errorf("could not start server: %v", err)
}
}()
return err
}
}()
return err
}

func validateDeletion(w http.ResponseWriter, r *http.Request) {
Expand Down
126 changes: 126 additions & 0 deletions pkg/apiserver/webhooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ package apiserver
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
admissionv1 "k8s.io/api/admission/v1"
Expand Down Expand Up @@ -58,6 +60,130 @@ func TestCreateCerts(t *testing.T) {
require.True(t, strings.HasSuffix(keyStr, "\n-----END RSA PRIVATE KEY-----"))
}

func TestLoadCertificate(t *testing.T) {

//what do i need to test.
// first create dummy certs for testing
dir := t.TempDir()
defer func() {
require.NoError(t, os.RemoveAll(dir))
}()
_, err := createCerts("", dir)
require.NoError(t, err)

//1. test file that cannot be os.stated or causes that function to fail
_, err1 := loadCertificate(filepath.Join(dir, "nonexistingcrtfile.key"), filepath.Join(dir, "nonexistingkeyfile.key"))
require.Error(t, err1)
// reseting back to 0
certModTime = time.Time{}
//2. test happy path of os.stat and continue to next error
//3. test loading good cert happy path
keypath := filepath.Join(dir, "tls.key")
crtpath := filepath.Join(dir, "tls.crt")

_, err2 := loadCertificate(crtpath, keypath)
require.NoError(t, err2)
certModTime = time.Time{}

//4. test loading faulty cert error
data := []byte("Hello, World!")

writeErr := os.WriteFile(keypath, data, 0644)
require.NoError(t, writeErr)
writeErr2 := os.WriteFile(crtpath, data, 0644)
require.NoError(t, writeErr2)

_, err3 := loadCertificate(filepath.Join(dir, "tls.crt"), filepath.Join(dir, "tls.key"))
require.Error(t, err3)
certModTime = time.Time{}
}

func captureStderr(f func()) string {
r, w, _ := os.Pipe()
stderr := os.Stderr
os.Stderr = w
outC := make(chan string)

// Copy the output in a separate goroutine so printing can't block indefinitely.
go func() {
var buf bytes.Buffer
io.Copy(&buf, r)
outC <- buf.String()
}()

// Run the provided function and capture the stderr output
f()

// Restore the original stderr and close the write-end of the pipe so the goroutine will exit
os.Stderr = stderr
w.Close()
out := <-outC

return out
}

func TestWatchCertificates(t *testing.T) {
// method for processing klog output
assertLogMessages := func(log string) error {
if len(log) > 0 {
if log[0] == 'E' || log[0] == 'W' || log[0] == 'F' {
return errors.New("Error Occured in Watcher")
}
}
return nil
}
// Set up the temp directory with dummy certificate files
dir := t.TempDir()
defer func() {
require.NoError(t, os.RemoveAll(dir))
}()
_, err := createCerts("", dir)
require.NoError(t, err)

keyFile := filepath.Join(dir, "tls.key")
certFile := filepath.Join(dir, "tls.crt")

// firstly test error occuring from invalid entity for watcher to watch. aka invalid dir
// we expect an error
go watchCertificates("Dummy Directory that does not exist", certFile, keyFile)

invalid_watch_entity_logs := captureStderr(func() {
time.Sleep(100 * time.Millisecond) // Give some time for the logs to be flushed
})
t.Log(invalid_watch_entity_logs)
err = assertLogMessages(invalid_watch_entity_logs)
require.Error(t, err)

go watchCertificates(dir, certFile, keyFile)
time.Sleep(1 * time.Second)

//create file to trigger change but not alter the certificate contents
//should trigger reload and certificate reloaded successfully
newFilePath := filepath.Join(dir, "new_temp_file.txt")
_, err = os.Create(newFilePath)
require.NoError(t, err)

valid_reload_logs := captureStderr(func() {
time.Sleep(100 * time.Millisecond) // Give some time for the logs to be flushed
})
t.Log(valid_reload_logs)
err = assertLogMessages(valid_reload_logs)
require.NoError(t, err)

// Modify the certificate file to trigger a file system event
// should cause an error log since cert contents are not valid anymore
certModTime = time.Time{}
err = os.WriteFile(certFile, []byte("dummy text"), 0660)
require.NoError(t, err)

invalid_reload_logs := captureStderr(func() {
time.Sleep(100 * time.Millisecond) // Give some time for the logs to be flushed
})
t.Log(invalid_reload_logs)
err = assertLogMessages(invalid_reload_logs)
require.Error(t, err)
}

func TestValidateDeletion(t *testing.T) {
t.Run("invalid content-type", func(t *testing.T) {
request, err := http.NewRequest(http.MethodPost, serverEndpoint, nil)
Expand Down

0 comments on commit 4430c61

Please sign in to comment.