Skip to content

Commit

Permalink
Merge pull request #18 from Mehrn0ush/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
Mehrn0ush authored Oct 22, 2024
2 parents bd757eb + 4601de6 commit 8a6e983
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/feature-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- name: Run cargo test
env:
REDIS_URL: redis://localhost:6379
JWT_SECRET: test_secret
run: cargo test

format:
Expand Down
12 changes: 6 additions & 6 deletions src/auth/rbac.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,7 @@ mod tests {

#[test]
fn test_extract_roles_success() {
dotenv::dotenv().ok(); // Ensure .env is loaded for consistency

// Ensure JWT_SECRET is set in the environment
let jwt_secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set in .env");
set_jwt_secret("test_secret"); // Set the secret directly for the test

// Define the claims
let claims = TestClaims {
Expand All @@ -285,8 +282,8 @@ mod tests {
roles: vec!["admin".to_string(), "user".to_string()],
};

// Generate the test token using the JWT_SECRET from environment
let token = generate_test_token(claims, &jwt_secret);
// Generate the test token using the same secret
let token = generate_test_token(claims, "test_secret");

// Log the generated token for debugging
println!("Generated token: {}", token);
Expand All @@ -308,5 +305,8 @@ mod tests {
let roles = result.unwrap();
assert!(roles.contains(&"admin".to_string()));
assert!(roles.contains(&"user".to_string()));

// Clean up
remove_jwt_secret();
}
}
72 changes: 72 additions & 0 deletions src/core/extension_grants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,75 @@ impl ExtensionGrantHandler for CustomGrant {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;

#[test]
fn test_generate_device_code() {
let handler = DefaultDeviceFlowHandler::new("https://example.com");
let response = handler.generate_device_code();

assert_eq!(response.device_code, "generated_device_code");
assert_eq!(response.verification_uri, "https://example.com/device");
assert!(response.expires_in > 0);
assert_eq!(response.interval, 5);
}

#[test]
fn test_poll_valid_device_code() {
let handler = DefaultDeviceFlowHandler::new("https://example.com");
let result = handler.poll_device_code("valid_device_code");

assert!(result.is_ok());
let token_response = result.unwrap();
assert_eq!(token_response.access_token, "device_access_token");
assert_eq!(token_response.token_type, "Bearer");
assert_eq!(token_response.expires_in, 3600);
assert_eq!(token_response.refresh_token, "device_refresh_token");
assert_eq!(token_response.scope.unwrap(), "read write");
}

#[test]
fn test_poll_invalid_device_code() {
let handler = DefaultDeviceFlowHandler::new("https://example.com");
let result = handler.poll_device_code("invalid_device_code");

assert!(result.is_err());
let error = result.unwrap_err();
assert_eq!(error, TokenError::InvalidGrant);
}

#[test]
fn test_handle_custom_grant() {
let handler = CustomGrant;
let mut params = HashMap::new();
params.insert("custom_param".to_string(), "value".to_string());

let result = handler
.handle_extension_grant("urn:ietf:params:oauth:grant-type:custom-grant", &params);
assert!(result.is_ok());

let token_response = result.unwrap();
assert_eq!(token_response.access_token, "custom_access_token");
assert_eq!(token_response.token_type, "Bearer");
assert_eq!(token_response.expires_in, 3600);
assert_eq!(token_response.refresh_token, "custom_refresh_token");
assert_eq!(token_response.scope.unwrap(), "custom_scope");
}

#[test]
fn test_handle_unsupported_grant() {
let handler = CustomGrant;
let mut params = HashMap::new();
params.insert("unsupported_param".to_string(), "value".to_string());

let result = handler.handle_extension_grant("unsupported_grant_type", &params);
assert!(result.is_err());

let error = result.unwrap_err();
assert_eq!(error, TokenError::UnsupportedGrantType);
}
}
75 changes: 72 additions & 3 deletions src/core/grants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,89 @@ impl ExtensionGrantHandler for CustomGrant {
// Add custom logic here
println!("Handling custom grant type");

//Validate required parameters from `params`
// Validate required parameters from `params`

// Return a token response, maybe after verifying params
// Return a token response after verifying params
Ok(TokenResponse {
access_token: "custom_access_token".to_string(),
token_type: "Bearer".to_string(),
expires_in: 3600,
refresh_token: Some("custom_refresh_token".to_string()),
refresh_token: "custom_refresh_token".to_string(), // Corrected type
scope: None, // Assuming no scope is passed, update if necessary
})
} else {
Err(TokenError::UnsupportedGrantType)
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;

#[test]
fn test_custom_grant_success() {
let grant_handler = CustomGrant;

// Define a HashMap with expected parameters
let mut params = HashMap::new();
params.insert("some_param".to_string(), "value".to_string());

// Call the handle_extension_grant function with the correct grant type
let result = grant_handler
.handle_extension_grant("urn:ietf:params:oauth:grant-type:custom-grant", &params);

// Ensure the result is OK and contains expected token data
assert!(result.is_ok());

// Unwrap the result to get the TokenResponse
let token_response = result.unwrap();

// Check that the token response contains the expected values
assert_eq!(token_response.access_token, "custom_access_token");
assert_eq!(token_response.token_type, "Bearer");
assert_eq!(token_response.expires_in, 3600);
assert_eq!(token_response.refresh_token, "custom_refresh_token");
assert!(token_response.scope.is_none()); // Expecting no scope
}

#[test]
fn test_custom_grant_unsupported_grant_type() {
let grant_handler = CustomGrant;

// Define a HashMap with expected parameters
let mut params = HashMap::new();
params.insert("some_param".to_string(), "value".to_string());

// Call the handle_extension_grant function with an unsupported grant type
let result = grant_handler.handle_extension_grant("unsupported-grant", &params);

// Ensure the result is an Err with TokenError::UnsupportedGrantType
assert!(result.is_err());
assert_eq!(result.unwrap_err(), TokenError::UnsupportedGrantType);
}

#[test]
fn test_custom_grant_missing_params() {
let grant_handler = CustomGrant;

// Call the handle_extension_grant function with empty parameters
let params = HashMap::new();
let result = grant_handler
.handle_extension_grant("urn:ietf:params:oauth:grant-type:custom-grant", &params);

// Ensure the result is OK (you can modify this to handle missing params logic)
assert!(result.is_ok());

// Unwrap the result to check token response
let token_response = result.unwrap();

// Check that the token response contains the expected values
assert_eq!(token_response.access_token, "custom_access_token");
assert_eq!(token_response.token_type, "Bearer");
assert_eq!(token_response.expires_in, 3600);
assert_eq!(token_response.refresh_token, "custom_refresh_token");
assert!(token_response.scope.is_none());
}
}
1 change: 1 addition & 0 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod authorization;
pub mod client_credentials;
pub mod device_flow;
pub mod extension_grants;
pub mod grants;
pub mod oidc_providers;
pub mod pkce;
pub mod refresh;
Expand Down
61 changes: 55 additions & 6 deletions src/core/pkce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use thiserror::Error; // For better error handling
const MIN_VERIFIER_LENGTH: usize = 43;
const MAX_VERIFIER_LENGTH: usize = 128;

#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum PkceError {
#[error("Verifier is too short or too long")]
InvalidVerifierLength,
Expand All @@ -19,11 +19,6 @@ pub enum PkceError {

// Function to validate if a PKCE verifier meets the length and character requirements
fn validate_verifier(verifier: &str) -> Result<(), PkceError> {
// Check that the verifier length is within the allowed range (43-128 characters)
if verifier.len() < MIN_VERIFIER_LENGTH || verifier.len() > MAX_VERIFIER_LENGTH {
return Err(PkceError::InvalidVerifierLength);
}

// Check that the verifier contains only valid characters (alphanumeric and "-._~")
if !verifier
.chars()
Expand All @@ -32,6 +27,11 @@ fn validate_verifier(verifier: &str) -> Result<(), PkceError> {
return Err(PkceError::InvalidVerifierCharacters);
}

// Check that the verifier length is within the allowed range (43-128 characters)
if verifier.len() < MIN_VERIFIER_LENGTH || verifier.len() > MAX_VERIFIER_LENGTH {
return Err(PkceError::InvalidVerifierLength);
}

Ok(())
}

Expand Down Expand Up @@ -62,3 +62,52 @@ pub fn validate_pkce_challenge(challenge: &str, verifier: &str) -> Result<(), Pk
Err(PkceError::InvalidVerifier)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_valid_pkce_verifier() {
let verifier = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~";
let challenge = generate_pkce_challenge(verifier);
assert!(challenge.is_ok());
}

#[test]
fn test_invalid_pkce_verifier_length() {
let short_verifier = "short";
let long_verifier = "a".repeat(200); // Over the max length

let short_result = generate_pkce_challenge(&short_verifier);
let long_result = generate_pkce_challenge(&long_verifier);

assert_eq!(short_result.unwrap_err(), PkceError::InvalidVerifierLength);
assert_eq!(long_result.unwrap_err(), PkceError::InvalidVerifierLength);
}

#[test]
fn test_invalid_pkce_verifier_characters() {
let invalid_verifier = "invalid@chars!";
let result = generate_pkce_challenge(invalid_verifier);
assert_eq!(result.unwrap_err(), PkceError::InvalidVerifierCharacters);
}

#[test]
fn test_valid_pkce_challenge_validation() {
let verifier = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~";
let challenge = generate_pkce_challenge(verifier).unwrap();

let result = validate_pkce_challenge(&challenge, verifier);
assert!(result.is_ok());
}

#[test]
fn test_invalid_pkce_challenge_validation() {
let verifier = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~";
let challenge = generate_pkce_challenge(verifier).unwrap();

let result = validate_pkce_challenge("incorrect_challenge", verifier);
assert_eq!(result.unwrap_err(), PkceError::InvalidVerifier);
}
}

0 comments on commit 8a6e983

Please sign in to comment.