Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: flatten (de)serialization of custom user claims #1159

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions examples/demo/src/controllers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async fn register(
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;
format::json(UserSession::new(&user, &token))
}
Expand Down Expand Up @@ -130,7 +130,7 @@ async fn login(State(ctx): State<AppContext>, Json(params): Json<LoginParams>) -
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;

format::json(UserSession::new(&user, &token))
Expand Down
13 changes: 6 additions & 7 deletions examples/demo/src/models/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use async_trait::async_trait;
use chrono::offset::Local;
use loco_rs::{auth::jwt, hash, prelude::*};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Map;
use uuid::Uuid;

pub use super::_entities::users::{self, ActiveModel, Entity, Model};
Expand Down Expand Up @@ -216,12 +216,11 @@ impl super::_entities::users::Model {
/// # Errors
///
/// when could not convert user claims to jwt token
pub fn generate_jwt(&self, secret: &str, expiration: &u64) -> ModelResult<String> {
Ok(jwt::JWT::new(secret).generate_token(
expiration,
self.pid.to_string(),
Some(json!({"Roll": "Administrator"})),
)?)
pub fn generate_jwt(&self, secret: &str, expiration: u64) -> ModelResult<String> {
let mut claims = Map::new();
claims.insert("Role".to_string(), "Administrator".into());
```?
Ok(jwt::JWT::new(secret).generate_token(expiration, self.pid.to_string(), claims)?)
}
}

Expand Down
21 changes: 9 additions & 12 deletions loco-new/base_template/src/controllers/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ async fn register(
/// Verify register user. if the user not verified his email, he can't login to
/// the system.
#[debug_handler]
async fn verify(
State(ctx): State<AppContext>,
Path(token): Path<String>,
) -> Result<Response> {
async fn verify(State(ctx): State<AppContext>, Path(token): Path<String>) -> Result<Response> {
let user = users::Model::find_by_verification_token(&ctx.db, &token).await?;

if user.email_verified_at.is_some() {
Expand Down Expand Up @@ -143,7 +140,7 @@ async fn login(State(ctx): State<AppContext>, Json(params): Json<LoginParams>) -
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;

format::json(LoginResponse::new(&user, &token))
Expand All @@ -158,14 +155,14 @@ async fn current(auth: auth::JWT, State(ctx): State<AppContext>) -> Result<Respo
/// Magic link authentication provides a secure and passwordless way to log in to the application.
///
/// # Flow
/// 1. **Request a Magic Link**:
/// A registered user sends a POST request to `/magic-link` with their email.
/// If the email exists, a short-lived, one-time-use token is generated and sent to the user's email.
/// 1. **Request a Magic Link**:
/// A registered user sends a POST request to `/magic-link` with their email.
/// If the email exists, a short-lived, one-time-use token is generated and sent to the user's email.
/// For security and to avoid exposing whether an email exists, the response always returns 200, even if the email is invalid.
///
/// 2. **Click the Magic Link**:
/// The user clicks the link (/magic-link/{token}), which validates the token and its expiration.
/// If valid, the server generates a JWT and responds with a [`LoginResponse`].
/// 2. **Click the Magic Link**:
/// The user clicks the link (/magic-link/{token}), which validates the token and its expiration.
/// If valid, the server generates a JWT and responds with a [`LoginResponse`].
/// If invalid or expired, an unauthorized response is returned.
///
/// This flow enhances security by avoiding traditional passwords and providing a seamless login experience.
Expand Down Expand Up @@ -211,7 +208,7 @@ async fn magic_link_verify(
let jwt_secret = ctx.config.get_jwt_config()?;

let token = user
.generate_jwt(&jwt_secret.secret, &jwt_secret.expiration)
.generate_jwt(&jwt_secret.secret, jwt_secret.expiration)
.or_else(|_| unauthorized("unauthorized!"))?;

format::json(LoginResponse::new(&user, &token))
Expand Down
5 changes: 3 additions & 2 deletions loco-new/base_template/src/models/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use async_trait::async_trait;
use chrono::{offset::Local, Duration};
use loco_rs::{auth::jwt, hash, prelude::*};
use serde::{Deserialize, Serialize};
use serde_json::Map;
use uuid::Uuid;

pub use super::_entities::users::{self, ActiveModel, Entity, Model};
Expand Down Expand Up @@ -258,8 +259,8 @@ impl Model {
/// # Errors
///
/// when could not convert user claims to jwt token
pub fn generate_jwt(&self, secret: &str, expiration: &u64) -> ModelResult<String> {
Ok(jwt::JWT::new(secret).generate_token(expiration, self.pid.to_string(), None)?)
pub fn generate_jwt(&self, secret: &str, expiration: u64) -> ModelResult<String> {
Ok(jwt::JWT::new(secret).generate_token(expiration, self.pid.to_string(), Map::new())?)
}
}

Expand Down
117 changes: 102 additions & 15 deletions src/auth/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@
//!
//! This module provides functionality for working with JSON Web Tokens (JWTs)
//! and password hashing.

use jsonwebtoken::{
decode, encode, errors::Result as JWTResult, get_current_timestamp, Algorithm, DecodingKey,
EncodingKey, Header, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{Map, Value};

/// Represents the default JWT algorithm used by the [`JWT`] struct.
const JWT_ALGORITHM: Algorithm = Algorithm::HS512;

/// Represents the claims associated with a user JWT.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added Eq and PartialEq so we can use assert_eq! with this struct

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like Eq and PartialEq are only needed for testing purposes. To address this, you can use #[cfg_attr(test, derive(Eq, PartialEq))] to conditionally derive them only in test builds.

pub struct UserClaims {
pub pid: String,
exp: u64,
pub claims: Option<Value>,
#[serde(default, flatten)]
// TODO: should we wrap this in an Option? `Option<Map<String, Value>>`
// so we can use `auth::jwt::JWT::new("PqRwLF2rhHe8J22oBeHy").generate_token(&604800, "PID".to_string(), None);
// TODO: serde_json::Map or std::collections::HashMap?
// TODO: is it ok to use a generic Map<String, Value> here? Or should we let the user specify their desired typed claim and
Comment on lines +21 to +24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please remove the TODOs?

// use generics to serialize/deserialize it?
pub claims: Map<String, Value>,
Copy link
Contributor Author

@jorgehermo9 jorgehermo9 Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serde_json::Map? HashMap? BtreeMap?

The advantage of serde_json::Map is that users can use the serde_json's preserve_order flag to control how the claims are serialized https://docs.rs/serde_json/latest/serde_json/enum.Value.html#variant.Object

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serde_json::Map for sure! Didn't know that existed.

}

/// Represents the JWT configuration and operations.
Expand Down Expand Up @@ -61,17 +66,18 @@ impl JWT {
///
/// # Example
/// ```rust
/// use serde_json::Map;
Copy link
Contributor Author

@jorgehermo9 jorgehermo9 Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I hide this import in the doctest or not?

Suggested change
/// use serde_json::Map;
/// # use serde_json::Map;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't hide it. I think it makes clear to the user what to import to use this feature.

/// use loco_rs::auth;
///
/// auth::jwt::JWT::new("PqRwLF2rhHe8J22oBeHy").generate_token(&604800, "PID".to_string(), None);
/// auth::jwt::JWT::new("PqRwLF2rhHe8J22oBeHy").generate_token(604800, "PID".to_string(), Map::new());
/// ```
pub fn generate_token(
&self,
expiration: &u64,
expiration: u64,
Copy link
Contributor Author

@jorgehermo9 jorgehermo9 Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u64 implements Copy (https://doc.rust-lang.org/std/marker/trait.Copy.html) so there is no need to use references here, we can simply copy it.

This may be a bit out of scope, I can address this in another PR if required, but seemed to me too little change and not worth of a separate PR

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only thing I can say is that it won't be a biggie for me :)

pid: String,
claims: Option<Value>,
claims: Map<String, Value>,
) -> JWTResult<String> {
let exp = get_current_timestamp().saturating_add(*expiration);
let exp = get_current_timestamp().saturating_add(expiration);

let claims = UserClaims { pid, exp, claims };

Expand Down Expand Up @@ -119,18 +125,27 @@ mod tests {
use super::*;

#[rstest]
#[case("valid token", 60, None)]
#[case("token expired", 1, None)]
#[case("valid token and custom claims", 60, Some(json!({})))]
#[tokio::test]
async fn can_generate_token(
#[case("valid token", 60, json!({}))]
#[case("token expired", 1, json!({}))]
#[case("valid token and custom string claims", 60, json!({ "custom": "claim",}))]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are those cases enough? too much? any specific case that you miss?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think this is more than enough. You could have a second look at the link below but it doesn't dictate much about the specific form of claims in a JWT.

JWT Spec

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, any JSON value should be valid as the key's value :)

#[case("valid token and custom boolean claims",60, json!({ "custom": true,}))]
#[case("valid token and custom number claims",60, json!({ "custom": 123,}))]
#[case("valid token and custom nested claims",60, json!({ "level1": { "level2": { "level3": "claim" } } }))]
#[case("valid token and custom array claims",60, json!({ "array": [1, 2, 3] }))]
#[case("valid token and custom nested array claims",60, json!({ "level1": { "level2": { "level3": [1, 2, 3] } } }))]
fn can_generate_token(
#[case] test_name: &str,
#[case] expiration: u64,
#[case] claims: Option<Value>,
#[case] json_claims: Value,
) {
let claims = json_claims
.as_object()
.expect("case input claims must be an object")
.clone();
let jwt = JWT::new("PqRwLF2rhHe8J22oBeHy");

let token = jwt
.generate_token(&expiration, "pid".to_string(), claims)
.generate_token(expiration, "pid".to_string(), claims)
.unwrap();

std::thread::sleep(std::time::Duration::from_secs(3));
Expand All @@ -140,4 +155,76 @@ mod tests {
assert_debug_snapshot!(test_name, jwt.validate(&token));
});
}

#[rstest]
#[case::without_custom_claims(json!({}))]
#[case::with_custom_string_claims(json!({ "custom": "claim",}))]
#[case::with_custom_boolean_claims(json!({ "custom": true,}))]
#[case::with_custom_number_claims(json!({ "custom": 123,}))]
#[case::with_custom_nested_claims(json!({ "level1": { "level2": { "level3": "claim" } } }))]
#[case::with_custom_array_claims(json!({ "array": [1, 2, 3] }))]
#[case::with_custom_nested_array_claims(json!({ "level1": { "level2": { "level3": [1, 2, 3] } } }))]
// we use `Value` to reduce code duplicity in the case inputs
fn serialize_user_claims(#[case] json_claims: Value) {
let claims = json_claims
.as_object()
.expect("case input claims must be an object")
.clone();
let input_user_claims = UserClaims {
pid: "pid".to_string(),
exp: 60,
claims: claims.clone(),
};

let mut expected_claim = Map::new();
expected_claim.insert("pid".to_string(), "pid".into());
expected_claim.insert("exp".to_string(), 60.into());
// we add the claims in a flattened way
expected_claim.extend(claims);
let expected_value = Value::from(expected_claim);

// We check between `Value` instead of `String` to avoid key ordering issues when serializing.
// It is because `expected_value` has all the keys in alphabetical order, as the `Value` serialization ensures that.
// But when serializing `input_user_claims`, first the `pid` and `exp` fields are serialized (in that order),
// and then the claims are serialized in alfabetic order. So, the resulting JSON string from the `input_user_claims` serialization
// may have the `pid` and `exp` fields unordered which differs from the `Value` serialization.
assert_eq!(
expected_value,
serde_json::to_value(&input_user_claims).unwrap()
);
}

#[rstest]
#[case::without_custom_claims(json!({}))]
#[case::with_custom_string_claims(json!({ "custom": "claim",}))]
#[case::with_custom_boolean_claims(json!({ "custom": true,}))]
#[case::with_custom_number_claims(json!({ "custom": 123,}))]
#[case::with_custom_nested_claims(json!({ "level1": { "level2": { "level3": "claim" } } }))]
#[case::with_custom_array_claims(json!({ "array": [1, 2, 3] }))]
#[case::with_custom_nested_array_claims(json!({ "level1": { "level2": { "level3": [1, 2, 3] } } }))]
// we use `Value` to reduce code duplicity in the case inputs
fn deserialize_user_claims(#[case] json_claims: Value) {
let claims = json_claims
.as_object()
.expect("case input claims must be an object")
.clone();

let mut input_claims = Map::new();
input_claims.insert("pid".to_string(), "pid".into());
input_claims.insert("exp".to_string(), 60.into());
// we add the claims in a flattened way
input_claims.extend(claims.clone());
let input_json = Value::from(input_claims).to_string();

let expected_user_claims = UserClaims {
pid: "pid".to_string(),
exp: 60,
claims,
};

assert_eq!(
expected_user_claims,
serde_json::from_str(&input_json).unwrap()
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
---
source: src/auth/jwt.rs
expression: jwt.validate(&token)
---
Ok(
TokenData {
header: Header {
typ: Some(
"JWT",
),
alg: HS512,
cty: None,
jku: None,
jwk: None,
kid: None,
x5u: None,
x5c: None,
x5t: None,
x5t_s256: None,
},
claims: UserClaims {
pid: "pid",
exp: EXP,
claims: {
"array": Array [
Number(1),
Number(2),
Number(3),
],
},
},
},
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this snapshot file's diff is marked as "renamed file", but actually I deleted the old valid token and custom claims snapshot, as it was unreferenced, and added a new valid token and custom boolean claims snapshot, such as the other from this PR.

source: src/auth/jwt.rs
assertion_line: 133
expression: jwt.validate(&token)
---
Ok(
Expand All @@ -22,9 +21,9 @@ Ok(
claims: UserClaims {
pid: "pid",
exp: EXP,
claims: Some(
Object {},
),
claims: {
"custom": Bool(true),
},
},
},
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
source: src/auth/jwt.rs
expression: jwt.validate(&token)
---
Ok(
TokenData {
header: Header {
typ: Some(
"JWT",
),
alg: HS512,
cty: None,
jku: None,
jwk: None,
kid: None,
x5u: None,
x5c: None,
x5t: None,
x5t_s256: None,
},
claims: UserClaims {
pid: "pid",
exp: EXP,
claims: {
"level1": Object {
"level2": Object {
"level3": Array [
Number(1),
Number(2),
Number(3),
],
},
},
},
},
},
)
Loading
Loading