From 5611171b1a8dbfdd62b2432dc27c55533c7a7ea0 Mon Sep 17 00:00:00 2001 From: bkawk Date: Sun, 5 Feb 2023 01:02:25 +0000 Subject: [PATCH] utils tests --- api/utils/generateJwt.go | 11 ++- api/utils/refreshToken.go | 9 +- api/utils/utils_test.go | 182 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 api/utils/utils_test.go diff --git a/api/utils/generateJwt.go b/api/utils/generateJwt.go index 4309016..4557237 100644 --- a/api/utils/generateJwt.go +++ b/api/utils/generateJwt.go @@ -1,6 +1,7 @@ package utils import ( + "errors" "os" "time" @@ -8,6 +9,14 @@ import ( ) func GenerateJWT(id string) (string, error) { + secret := os.Getenv("JWT_SECRET") + if secret == "" { + return "", errors.New("JWT_SECRET environment variable is not set") + } + if id == "" { + return "", errors.New("ID cannot be empty") + } + // Create a new JWT token token := jwt.New(jwt.SigningMethodHS256) @@ -17,7 +26,7 @@ func GenerateJWT(id string) (string, error) { claims["exp"] = time.Now().Add(time.Hour * 72).Unix() // Generate encoded token and send it as response. - t, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) + t, err := token.SignedString([]byte(secret)) if err != nil { return "", err } diff --git a/api/utils/refreshToken.go b/api/utils/refreshToken.go index c43ce83..b16877b 100644 --- a/api/utils/refreshToken.go +++ b/api/utils/refreshToken.go @@ -14,13 +14,18 @@ const ( // GenerateRefreshToken generates a new refresh token func GenerateRefreshToken() (string, error) { - b := make([]byte, RefreshTokenLength) + b := make([]byte, (RefreshTokenLength*3)/4) _, err := rand.Read(b) if err != nil { return "", err } - return base64.URLEncoding.EncodeToString(b), nil + encoded := base64.URLEncoding.EncodeToString(b) + if len(encoded) > RefreshTokenLength { + encoded = encoded[:RefreshTokenLength] + } + + return encoded, nil } // ValidateRefreshToken checks if a refresh token is valid and has not expired diff --git a/api/utils/utils_test.go b/api/utils/utils_test.go new file mode 100644 index 0000000..e00a4a9 --- /dev/null +++ b/api/utils/utils_test.go @@ -0,0 +1,182 @@ +package utils + +import ( + "fmt" + "os" + "testing" +) + +func TestGenerateJWT_Success(t *testing.T) { + // Setup + id := "user123" + os.Setenv("JWT_SECRET", "secret_key") + + // Test + token, err := GenerateJWT(id) + + // Assertions + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if token == "" { + t.Errorf("Expected token, got empty string") + } +} + +func TestGenerateJWT_MissingSecret(t *testing.T) { + // Setup + id := "user123" + os.Unsetenv("JWT_SECRET") + + // Test + _, err := GenerateJWT(id) + + // Assertions + if err == nil { + t.Errorf("Expected error, got nil") + } +} + +func TestGenerateJWT_EmptyID(t *testing.T) { + // Setup + id := "" + os.Setenv("JWT_SECRET", "secret_key") + + // Test + _, err := GenerateJWT(id) + + // Assertions + if err == nil { + t.Errorf("Expected error, got nil") + } +} + +func TestValidatePasswordShortPassword(t *testing.T) { + err := ValidatePassword("pass") + expected := fmt.Errorf("password must be at least 8 characters long") + + if err == nil || err.Error() != expected.Error() { + t.Errorf("ValidatePassword(\"pass\") = %v; want %v", err, expected) + } +} +func TestValidatePasswordNoUppercase(t *testing.T) { + err := ValidatePassword("password") + expected := fmt.Errorf("password must contain at least one uppercase letter") + + if err == nil || err.Error() != expected.Error() { + t.Errorf("ValidatePassword(\"password\") = %v; want %v", err, expected) + } +} + +func TestValidatePasswordNoLowercase(t *testing.T) { + err := ValidatePassword("PASSWORD") + expected := fmt.Errorf("password must contain at least one lowercase letter") + + if err == nil || err.Error() != expected.Error() { + t.Errorf("ValidatePassword(\"PASSWORD\") = %v; want %v", err, expected) + } +} + +func TestValidatePasswordNoDigit(t *testing.T) { + err := ValidatePassword("Password") + expected := fmt.Errorf("password must contain at least one digit") + + if err == nil || err.Error() != expected.Error() { + t.Errorf("ValidatePassword(\"Password\") = %v; want %v", err, expected) + } +} + +func TestValidatePasswordNoSpecialCharacter(t *testing.T) { + err := ValidatePassword("Password1") + expected := fmt.Errorf("password must contain at least one special character") + + if err == nil || err.Error() != expected.Error() { + t.Errorf("ValidatePassword(\"Password1\") = %v; want %v", err, expected) + } +} + +func TestValidatePasswordEmptyPassword(t *testing.T) { + err := ValidatePassword("") + expected := fmt.Errorf("password must be at least 8 characters long") + + if err == nil || err.Error() != expected.Error() { + t.Errorf("ValidatePassword(\"\") = %v; want %v", err, expected) + } +} + +func TestGenerateRefreshToken(t *testing.T) { + token, err := GenerateRefreshToken() + if err != nil { + t.Errorf("GenerateRefreshToken returned error: %v", err) + } + + if len(token) == 0 { + t.Error("GenerateRefreshToken returned an empty string") + } +} + +func TestGenerateRefreshTokenSuccess(t *testing.T) { + token, err := GenerateRefreshToken() + if err != nil { + t.Errorf("GenerateRefreshToken returned error: %v", err) + } + + if len(token) != RefreshTokenLength { + t.Errorf("GenerateRefreshToken returned a token of length %d, expected length %d", len(token), RefreshTokenLength) + } +} + +func TestGenerateRefreshTokenLength(t *testing.T) { + token, err := GenerateRefreshToken() + if err != nil { + t.Errorf("GenerateRefreshToken returned error: %v", err) + } + + if len(token) != RefreshTokenLength { + t.Errorf("GenerateRefreshToken returned a token of length %d, expected length %d", len(token), RefreshTokenLength) + } +} + +func TestGenerateRefreshTokenUniqueness(t *testing.T) { + tokens := make(map[string]bool) + + for i := 0; i < 100; i++ { + token, err := GenerateRefreshToken() + if err != nil { + t.Errorf("GenerateRefreshToken returned error: %v", err) + } + + if _, exists := tokens[token]; exists { + t.Errorf("GenerateRefreshToken returned a duplicate token: %s", token) + break + } + + tokens[token] = true + } +} + +func TestGenerateUUID(t *testing.T) { + ids := make(map[string]bool) + + for i := 0; i < 10000; i++ { + id := GenerateUUID() + + if _, exists := ids[id]; exists { + t.Errorf("GenerateUUID returned a duplicate identifier: %s", id) + break + } + + ids[id] = true + } +} + +func TestGenerateUUIDLength(t *testing.T) { + expectedLength := 36 + for i := 0; i < 100; i++ { + id := GenerateUUID() + if len(id) != expectedLength { + t.Errorf("GenerateUUID returned an unexpected length identifier: got %d, expected %d", len(id), expectedLength) + break + } + } +}