Skip to content

Commit

Permalink
Customize access token (#24)
Browse files Browse the repository at this point in the history
* add custom claims to access_token

* remove cargo-make

* lint

* test duplication

* comment removed

* remove useless complexity

* Update src/model/claims.rs

Co-authored-by: Simone Cottini <[email protected]>

* use plain text string instead of reading from file

* trailing space

Co-authored-by: Simone Cottini <[email protected]>
  • Loading branch information
dibericky and cottinisimone authored Oct 6, 2022
1 parent 70f5b69 commit 300b905
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 14 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions localauth0.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ permissions = ["audience1:permission1", "audience1:permission2"]
[[audience]]
name = "audience2"
permissions = ["audience2:permission2"]

[access_token]
custom_claims = [
{ name = "at_custom_claims_str", value = { String = "str" } }
]
35 changes: 33 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub struct Config {

#[serde(default)]
user: Vec<User>,

#[serde(default)]
access_token: AccessToken,
}

#[derive(Debug, Error)]
Expand Down Expand Up @@ -50,6 +53,7 @@ impl Config {
user_info: Default::default(),
audience: vec![],
user: vec![],
access_token: Default::default(),
}
}
}
Expand Down Expand Up @@ -109,13 +113,26 @@ impl Default for UserInfo {
}
}

#[derive(Debug, Deserialize, Getters)]
#[derive(Debug, Deserialize, Getters, Default)]
pub struct AccessToken {
#[serde(default)]
custom_claims: Vec<CustomField>,
}

#[derive(Debug, Deserialize, Getters, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub struct CustomField {
name: String,
value: CustomFieldValue,
}

#[derive(Debug, Deserialize)]
impl CustomField {
pub fn new(name: String, value: CustomFieldValue) -> Self {
Self { name, value }
}
}

#[derive(Debug, Deserialize, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CustomFieldValue {
String(String),
Expand Down Expand Up @@ -171,6 +188,11 @@ mod tests {
[[audience]]
name = "audience2"
permissions = ["audience2:permission2"]
[access_token]
custom_claims = [
{ name = "at_custom_claim_str", value = { String = "str" } }
]
"#;

let config: Config = toml::from_str(config_str).unwrap();
Expand Down Expand Up @@ -207,5 +229,14 @@ mod tests {

let custom_field: &CustomField = custom_fields.iter().find(|v| v.name == "custom_field_str").unwrap();
assert_eq!(custom_field.value, CustomFieldValue::String("str".to_string()));

let access_token = config.access_token();
let at_custom_claims = access_token.custom_claims();

let at_custom_claim: &CustomField = at_custom_claims
.iter()
.find(|v| v.name == "at_custom_claim_str")
.unwrap();
assert_eq!(at_custom_claim.value, CustomFieldValue::String("str".to_string()))
}
}
68 changes: 68 additions & 0 deletions src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,13 @@ fn new_token_response(app_data: &AppData, audience: &str, grant_type: GrantType)
.get_permissions(audience)
.expect("Failed to get permissions");

let custom_claims = app_data.config().access_token().custom_claims().to_owned();
let claims: Claims = Claims::new(
audience.to_string(),
permissions,
app_data.config().issuer().to_string(),
grant_type,
custom_claims,
);

let user_info: UserInfo = UserInfo::new(app_data.config(), audience.to_string());
Expand All @@ -175,3 +177,69 @@ fn new_token_response(app_data: &AppData, audience: &str, grant_type: GrantType)

TokenResponse::new(access_token, id_token, None)
}

#[cfg(test)]
mod test {
use crate::{
config::Config,
model::{AppData, GrantType},
};

use super::new_token_response;

#[test]
fn generate_access_token_with_custom_claims() {
let config_string: &str = r#"
issuer = "https://prima.localauth0.com/"
[user_info]
name = "Local"
given_name = "Locie"
family_name = "Auth0"
gender = "none"
birthdate = "2022-02-11"
email = "[email protected]"
picture = "https://github.com/primait/localauth0/blob/6f71c9318250219a9d03fb72afe4308b8824aef7/web/assets/static/media/localauth0.png"
custom_fields = [
{ name = "address", value = { String = "github street" } },
{ name = "roles", value = { Vec = ["fake:auth"] } }
]
[[audience]]
name = "audience1"
permissions = ["audience1:permission1", "audience1:permission2"]
[[audience]]
name = "audience2"
permissions = ["audience2:permission2"]
[access_token]
custom_claims = [
{ name = "at_custom_claims_str", value = { String = "str" } }
]
"#;
let config: Config = toml::from_str(config_string).unwrap();
let app_data = AppData::new(config).unwrap();
let audience = "my-audience";
let grant_type = GrantType::AuthorizationCode;

let token_response = new_token_response(&app_data, audience, grant_type);

let access_token = token_response.access_token();
let jwks = app_data.jwks().get().unwrap();
let claims_json: serde_json::Value = jwks
.parse(access_token, &[audience])
.expect("failed to parse access_token");

assert_eq!(claims_json.get("aud").unwrap(), "my-audience");
assert!(claims_json.get("iat").is_some());
assert!(claims_json.get("exp").is_some());
assert!(claims_json.get("scope").is_some());
assert_eq!(claims_json.get("iss").unwrap(), "https://prima.localauth0.com/");
assert_eq!(claims_json.get("gty").unwrap(), "authorization_code");
assert!(claims_json.get("permissions").is_some());

assert_eq!(claims_json.get("at_custom_claims_str").unwrap(), "str");
}
}
50 changes: 46 additions & 4 deletions src/model/claims.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
use serde::{ser::SerializeMap, Deserialize, Serialize, Serializer};
use std::fmt::{Display, Formatter};

use serde::{Deserialize, Serialize};
use crate::config::{CustomField, CustomFieldValue};

#[derive(Serialize, Deserialize, Debug)]
#[derive(Debug, Deserialize)]
pub struct Claims {
aud: String,
iat: Option<i64>,
exp: Option<i64>,
scope: String,
iss: String,
gty: GrantType,
#[serde(default)]
permissions: Vec<String>,
// skip deserializing since deserialization from a jwt wouldn't match this struct
// a custom deserializer would be needed
#[serde(skip_deserializing)]
custom_claims: Vec<CustomField>,
}

impl Claims {
pub fn new(aud: String, permissions: Vec<String>, iss: String, gty: GrantType) -> Self {
pub fn new(
aud: String,
permissions: Vec<String>,
iss: String,
gty: GrantType,
custom_claims: Vec<CustomField>,
) -> Self {
Self {
aud,
iat: Some(chrono::Utc::now().timestamp()),
Expand All @@ -24,6 +34,7 @@ impl Claims {
iss,
gty,
permissions,
custom_claims,
}
}

Expand All @@ -42,6 +53,37 @@ impl Claims {
pub fn grant_type(&self) -> &GrantType {
&self.gty
}

#[cfg(test)]
pub fn custom_claims(&self) -> &Vec<CustomField> {
&&self.custom_claims
}
}

impl Serialize for Claims {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(None)?;

map.serialize_entry("aud", &self.aud)?;
map.serialize_entry("iat", &self.iat)?;
map.serialize_entry("exp", &self.exp)?;
map.serialize_entry("scope", &self.scope)?;
map.serialize_entry("iss", &self.iss)?;
map.serialize_entry("gty", &self.gty)?;
map.serialize_entry("permissions", &self.permissions)?;

for custom_claims in &self.custom_claims {
match custom_claims.value() {
CustomFieldValue::String(string) => map.serialize_entry(custom_claims.name(), &string),
CustomFieldValue::Vec(vec) => map.serialize_entry(custom_claims.name(), &vec),
}?;
}

map.end()
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down
69 changes: 67 additions & 2 deletions src/model/jwks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,11 @@ fn generate_x509_cert(key_pair: &PKey<Private>) -> Result<X509, ErrorStack> {

#[cfg(test)]
mod tests {
use crate::config::CustomField;
use crate::error::Error;
use crate::model::jwks::JwksStore;
use crate::model::{Claims, GrantType, Jwk, Jwks};
use serde_json::json;

#[test]
fn its_possible_to_generate_jwks_and_parse_claims_using_given_jwks_test() {
Expand All @@ -232,17 +234,80 @@ mod tests {
vec![permission.to_string()],
issuer.to_string(),
gty.clone(),
vec![],
);

let jwt: String = random_jwk.encode(&claims).unwrap();

let result: Result<Claims, Error> = jwks.parse(jwt.as_ref(), &[audience]);

assert!(result.is_ok());

let claims: Claims = result.unwrap();
assert_eq!(claims.audience(), audience);
assert!(claims.has_permission(permission));
assert_eq!(claims.issuer(), issuer);
assert_eq!(claims.grant_type().to_string(), gty.to_string());
}

#[test]
fn use_custom_claims_test() {
let jwk_store: JwksStore = JwksStore::new().unwrap();
let audience: &str = "audience";
let permission: &str = "permission";
let issuer: &str = "issuer";
let gty: GrantType = GrantType::ClientCredentials;

let jwks: Jwks = jwk_store.get().unwrap();
let random_jwk: Jwk = jwks.random_jwk().unwrap();
let custom_claims: Vec<CustomField> = vec![
serde_json::from_value(json!({ "name": "at_custom_claims_str", "value": { "String": "my_str" } })).unwrap(),
serde_json::from_value(json!({"name": "at_custom_claims_vec", "value": {"Vec": ["foobar"]}})).unwrap(),
];

let claims: Claims = Claims::new(
audience.to_string(),
vec![permission.to_string()],
issuer.to_string(),
gty.clone(),
custom_claims.clone(),
);

let jwt: String = random_jwk.encode(&claims).unwrap();
let content = jwks
.parse::<serde_json::Value>(jwt.as_ref(), &[audience])
.expect("unable to parse jwt");
assert_eq!(content.get("at_custom_claims_str").unwrap(), "my_str");
let custom_claim_vec: Vec<String> =
serde_json::from_value(content.get("at_custom_claims_vec").unwrap().to_owned()).unwrap();
assert_eq!(custom_claim_vec, vec!["foobar".to_string()]);
}

#[test]
fn duplicated_custom_claim_keeps_the_last_one() {
let jwk_store: JwksStore = JwksStore::new().unwrap();
let audience: &str = "audience";
let permission: &str = "permission";
let issuer: &str = "issuer";
let gty: GrantType = GrantType::ClientCredentials;

let jwks: Jwks = jwk_store.get().unwrap();
let random_jwk: Jwk = jwks.random_jwk().unwrap();
let custom_claims: Vec<CustomField> = vec![
serde_json::from_value(json!({ "name": "at_custom_claims_str", "value": { "String": "my-str-1" } }))
.unwrap(),
serde_json::from_value(json!({ "name": "at_custom_claims_str", "value": { "String": "my-str-2" } }))
.unwrap(),
];

let claims: Claims = Claims::new(
audience.to_string(),
vec![permission.to_string()],
issuer.to_string(),
gty.clone(),
custom_claims.clone(),
);

let jwt: String = random_jwk.encode(&claims).unwrap();
let result = jwks.parse::<serde_json::Value>(jwt.as_ref(), &[audience]).unwrap();
assert_eq!(result.get("at_custom_claims_str").unwrap(), "my-str-2");
}
}
5 changes: 5 additions & 0 deletions src/model/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ impl TokenResponse {
token_type: BEARER.to_string(),
}
}

#[cfg(test)]
pub fn access_token(&self) -> &str {
&self.access_token
}
}

#[derive(Serialize)]
Expand Down

0 comments on commit 300b905

Please sign in to comment.