Skip to content

Commit

Permalink
Updated to use the latest base64 crate
Browse files Browse the repository at this point in the history
  • Loading branch information
mibes404 committed May 6, 2024
1 parent ddebcaf commit 0191ee5
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 70 deletions.
3 changes: 1 addition & 2 deletions examples/jwks_client.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#[cfg(feature = "remote-jwks")]
#[tokio::main]
async fn main() -> jwtk::Result<()> {
use std::time::Duration;

use jwtk::jwk::RemoteJwksVerifier;
use serde::Deserialize;
use serde_json::{Map, Value};
use std::time::Duration;

#[derive(Deserialize)]
struct Token {
Expand Down
19 changes: 10 additions & 9 deletions src/ecdsa.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::Engine as _;
use foreign_types::ForeignTypeRef;
use openssl::{
bn::{BigNum, BigNumContext},
Expand All @@ -11,8 +12,8 @@ use openssl_sys::BN_bn2bin;
use smallvec::{smallvec, SmallVec};

use crate::{
jwk::Jwk, url_safe_trailing_bits, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey,
VerificationKey,
jwk::Jwk, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey, VerificationKey,
URL_SAFE_TRAILING_BITS,
};

#[non_exhaustive]
Expand Down Expand Up @@ -196,8 +197,8 @@ impl PublicKeyToJwk for EcdsaPrivateKey {
kty: "EC".into(),
use_: Some("sig".into()),
crv: Some(self.algorithm.curve_name().into()),
x: Some(base64::encode_config(x, url_safe_trailing_bits())),
y: Some(base64::encode_config(y, url_safe_trailing_bits())),
x: Some(URL_SAFE_TRAILING_BITS.encode(x)),
y: Some(URL_SAFE_TRAILING_BITS.encode(y)),
..Default::default()
})
}
Expand All @@ -211,9 +212,9 @@ impl PrivateKeyToJwk for EcdsaPrivateKey {
kty: "EC".into(),
use_: Some("sig".into()),
crv: Some(self.algorithm.curve_name().into()),
d: Some(base64::encode_config(d, url_safe_trailing_bits())),
x: Some(base64::encode_config(x, url_safe_trailing_bits())),
y: Some(base64::encode_config(y, url_safe_trailing_bits())),
d: Some(URL_SAFE_TRAILING_BITS.encode(d)),
x: Some(URL_SAFE_TRAILING_BITS.encode(x)),
y: Some(URL_SAFE_TRAILING_BITS.encode(y)),
..Default::default()
})
}
Expand Down Expand Up @@ -307,8 +308,8 @@ impl PublicKeyToJwk for EcdsaPublicKey {
kty: "EC".into(),
use_: Some("sig".into()),
crv: Some(self.algorithm.curve_name().into()),
x: Some(base64::encode_config(x, url_safe_trailing_bits())),
y: Some(base64::encode_config(y, url_safe_trailing_bits())),
x: Some(URL_SAFE_TRAILING_BITS.encode(x)),
y: Some(URL_SAFE_TRAILING_BITS.encode(y)),
..Default::default()
})
}
Expand Down
21 changes: 10 additions & 11 deletions src/eddsa.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use std::ptr;

use crate::{
jwk::Jwk, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey, VerificationKey,
URL_SAFE_TRAILING_BITS,
};
use base64::Engine as _;
use foreign_types::ForeignType;
use openssl::{
error::ErrorStack,
pkey::{PKey, Private, Public},
sign::{Signer, Verifier},
};
use smallvec::SmallVec;

use crate::{
jwk::Jwk, url_safe_trailing_bits, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey,
VerificationKey,
};
use std::ptr;

#[derive(Debug, Clone)]
pub struct Ed25519PrivateKey {
Expand Down Expand Up @@ -100,7 +99,7 @@ impl PublicKeyToJwk for Ed25519PrivateKey {
Ok(Jwk {
kty: "OKP".into(),
crv: Some("Ed25519".into()),
x: Some(base64::encode_config(bytes, url_safe_trailing_bits())),
x: Some(URL_SAFE_TRAILING_BITS.encode(bytes)),
..Jwk::default()
})
}
Expand All @@ -113,8 +112,8 @@ impl PrivateKeyToJwk for Ed25519PrivateKey {
Ok(Jwk {
kty: "OKP".into(),
crv: Some("Ed25519".into()),
d: Some(base64::encode_config(d, url_safe_trailing_bits())),
x: Some(base64::encode_config(x, url_safe_trailing_bits())),
d: Some(URL_SAFE_TRAILING_BITS.encode(d)),
x: Some(URL_SAFE_TRAILING_BITS.encode(x)),
..Jwk::default()
})
}
Expand Down Expand Up @@ -181,7 +180,7 @@ impl PublicKeyToJwk for Ed25519PublicKey {
Ok(Jwk {
kty: "OKP".into(),
crv: Some("Ed25519".into()),
x: Some(base64::encode_config(bytes, url_safe_trailing_bits())),
x: Some(URL_SAFE_TRAILING_BITS.encode(bytes)),
..Jwk::default()
})
}
Expand Down
38 changes: 16 additions & 22 deletions src/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
//!
//! Only public keys are really supported for now.
use std::collections::{BTreeMap, HashMap};

use crate::{
ecdsa::{EcdsaAlgorithm, EcdsaPrivateKey, EcdsaPublicKey},
eddsa::{Ed25519PrivateKey, Ed25519PublicKey},
rsa::{RsaAlgorithm, RsaPrivateKey, RsaPublicKey},
some::SomePublicKey,
url_safe_trailing_bits, verify, verify_only, Error, Header, HeaderAndClaims, PublicKeyToJwk,
Result, SigningKey, SomePrivateKey, VerificationKey,
verify, verify_only, Error, Header, HeaderAndClaims, PublicKeyToJwk, Result, SigningKey,
SomePrivateKey, VerificationKey, URL_SAFE_TRAILING_BITS,
};
use base64::Engine as _;
use openssl::{
bn::BigNum,
hash::{hash, MessageDigest},
Expand All @@ -20,6 +19,7 @@ use openssl::{
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use std::collections::{BTreeMap, HashMap};

// TODO: private key jwk.

Expand Down Expand Up @@ -80,8 +80,8 @@ impl Jwk {
match &*self.kty {
"RSA" => match (self.alg.as_deref(), &self.n, &self.e) {
(alg, Some(ref n), Some(ref e)) => {
let n = base64::decode_config(n, url_safe_trailing_bits())?;
let e = base64::decode_config(e, url_safe_trailing_bits())?;
let n = URL_SAFE_TRAILING_BITS.decode(n)?;
let e = URL_SAFE_TRAILING_BITS.decode(e)?;
// If `alg` is specified, the key will only verify
// signatures generated by ONLY this specific `alg`,
// otherwise it will verify signatures generated by ANY RSA
Expand All @@ -100,8 +100,8 @@ impl Jwk {
"EC" => match (self.crv.as_deref(), &self.x, &self.y) {
// For EC keys `crv` is required.
(Some(crv), Some(ref x), Some(ref y)) => {
let x = base64::decode_config(x, url_safe_trailing_bits())?;
let y = base64::decode_config(y, url_safe_trailing_bits())?;
let x = URL_SAFE_TRAILING_BITS.decode(x)?;
let y = URL_SAFE_TRAILING_BITS.decode(y)?;
let alg = EcdsaAlgorithm::from_curve_name(crv)?;
return Ok(SomePublicKey::Ecdsa(EcdsaPublicKey::from_coordinates(
&x, &y, alg,
Expand All @@ -111,7 +111,7 @@ impl Jwk {
},
"OKP" => match (self.crv.as_deref(), &self.x) {
(Some(crv), Some(ref x)) => {
let x = base64::decode_config(x, url_safe_trailing_bits())?;
let x = URL_SAFE_TRAILING_BITS.decode(x)?;
match crv {
"Ed25519" => {
return Ok(SomePublicKey::Ed25519(Ed25519PublicKey::from_bytes(&x)?));
Expand Down Expand Up @@ -139,10 +139,7 @@ impl Jwk {
match (self.d.as_deref(), self.n.as_deref(), self.e.as_deref()) {
(Some(d), Some(n), Some(e)) => {
fn decode(x: &str) -> Result<BigNum> {
Ok(BigNum::from_slice(&base64::decode_config(
x,
url_safe_trailing_bits(),
)?)?)
Ok(BigNum::from_slice(&URL_SAFE_TRAILING_BITS.decode(x)?)?)
}
let d = decode(d)?;
let n = decode(n)?;
Expand Down Expand Up @@ -185,17 +182,17 @@ impl Jwk {
) {
(Some(crv), Some(d), Some(x), Some(y)) => {
let alg = EcdsaAlgorithm::from_curve_name(crv)?;
let d = base64::decode_config(d, url_safe_trailing_bits())?;
let x = base64::decode_config(x, url_safe_trailing_bits())?;
let y = base64::decode_config(y, url_safe_trailing_bits())?;
let d = URL_SAFE_TRAILING_BITS.decode(d)?;
let x = URL_SAFE_TRAILING_BITS.decode(x)?;
let y = URL_SAFE_TRAILING_BITS.decode(y)?;
EcdsaPrivateKey::from_private_components(alg, &d, &x, &y).map(Into::into)
}
_ => Err(Error::UnsupportedOrInvalidKey),
}
}
"OKP" => match (self.crv.as_deref(), self.d.as_deref()) {
(Some("Ed25519"), Some(d)) => {
let d = base64::decode_config(d, url_safe_trailing_bits())?;
let d = URL_SAFE_TRAILING_BITS.decode(d)?;
Ed25519PrivateKey::from_bytes(&d).map(Into::into)
}
_ => Err(Error::UnsupportedOrInvalidKey),
Expand Down Expand Up @@ -260,10 +257,7 @@ impl Jwk {

/// Get key thumbprint with SHA-256, base64url-encoded.
pub fn get_thumbprint_sha256_base64(&self) -> Result<String> {
Ok(base64::encode_config(
self.get_thumbprint_sha256()?,
url_safe_trailing_bits(),
))
Ok(URL_SAFE_TRAILING_BITS.encode(self.get_thumbprint_sha256()?))
}
}

Expand Down Expand Up @@ -341,7 +335,7 @@ impl JwkSetVerifier {

let mut header = parts.next().ok_or(Error::InvalidToken)?.as_bytes();

let header_r = base64::read::DecoderReader::new(&mut header, url_safe_trailing_bits());
let header_r = base64::read::DecoderReader::new(&mut header, &URL_SAFE_TRAILING_BITS);
let header: Header = serde_json::from_reader(header_r)?;

if let Some(kid) = header.kid {
Expand Down
37 changes: 21 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#![doc = include_str!("../README.md")]

use base64::Engine as _;
use base64::{
alphabet,
engine::{general_purpose::NO_PAD, GeneralPurpose},
};
use openssl::error::ErrorStack;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::{Map, Value};
Expand Down Expand Up @@ -232,10 +237,10 @@ impl HeaderAndClaims<Map<String, Value>> {
}
}

#[inline(always)]
fn url_safe_trailing_bits() -> base64::Config {
base64::URL_SAFE_NO_PAD.decode_allow_trailing_bits(true)
}
pub const URL_SAFE_TRAILING_BITS: GeneralPurpose = GeneralPurpose::new(
&alphabet::URL_SAFE,
NO_PAD.with_decode_allow_trailing_bits(true),
);

/// Encode and sign this header and claims with the signing key.
///
Expand All @@ -252,12 +257,12 @@ pub fn sign<ExtraClaims: Serialize>(
claims.set_kid(kid);
}

let mut w = base64::write::EncoderStringWriter::new(url_safe_trailing_bits());
let mut w = base64::write::EncoderStringWriter::new(&URL_SAFE_TRAILING_BITS);
serde_json::to_writer(&mut w, &claims.header)?;

let mut buf = w.into_inner();
buf.push('.');
let mut w = base64::write::EncoderStringWriter::from(buf, url_safe_trailing_bits());
let mut w = base64::write::EncoderStringWriter::from_consumer(buf, &URL_SAFE_TRAILING_BITS);

serde_json::to_writer(&mut w, &claims.claims)?;
let mut buf = w.into_inner();
Expand All @@ -266,7 +271,7 @@ pub fn sign<ExtraClaims: Serialize>(

buf.push('.');

let mut w = base64::write::EncoderStringWriter::from(buf, url_safe_trailing_bits());
let mut w = base64::write::EncoderStringWriter::from_consumer(buf, &URL_SAFE_TRAILING_BITS);
w.write_all(&sig)?;
Ok(w.into_inner())
}
Expand Down Expand Up @@ -315,10 +320,10 @@ pub fn verify_only<ExtraClaims: DeserializeOwned>(
return Err(Error::InvalidToken);
}

let header_r = base64::read::DecoderReader::new(&mut header, url_safe_trailing_bits());
let header_r = base64::read::DecoderReader::new(&mut header, &URL_SAFE_TRAILING_BITS);
let header: Header = serde_json::from_reader(header_r)?;

let sig = base64::decode_config(sig, url_safe_trailing_bits())?;
let sig = URL_SAFE_TRAILING_BITS.decode(sig)?;

// Verify the signature.
k.verify(
Expand All @@ -327,7 +332,7 @@ pub fn verify_only<ExtraClaims: DeserializeOwned>(
&header.alg,
)?;

let payload_r = base64::read::DecoderReader::new(&mut payload, url_safe_trailing_bits());
let payload_r = base64::read::DecoderReader::new(&mut payload, &URL_SAFE_TRAILING_BITS);
let claims: Claims<ExtraClaims> = serde_json::from_reader(payload_r)?;

Ok(HeaderAndClaims { header, claims })
Expand All @@ -348,10 +353,10 @@ pub fn decode_without_verify<ExtraClaims: DeserializeOwned>(
return Err(Error::InvalidToken);
}

let header_r = base64::read::DecoderReader::new(&mut header, url_safe_trailing_bits());
let header_r = base64::read::DecoderReader::new(&mut header, &URL_SAFE_TRAILING_BITS);
let header: Header = serde_json::from_reader(header_r)?;

let payload_r = base64::read::DecoderReader::new(&mut payload, url_safe_trailing_bits());
let payload_r = base64::read::DecoderReader::new(&mut payload, &URL_SAFE_TRAILING_BITS);
let claims: Claims<ExtraClaims> = serde_json::from_reader(payload_r)?;

Ok(HeaderAndClaims { header, claims })
Expand Down Expand Up @@ -532,12 +537,12 @@ mod tests {

#[test]
fn claim_deserialization() {
let mut json = r#"eyJpYXQiOjEuNjkyMTkwMTI1RTksImV4cCI6MS42OTIxOTM3MjVFOSwiYW50aUNzcmZUb2tlbiI6bnVsbCwic3ViIjoiYTM5ZmZjNWUtNjc5ZC00YjAzLWI5YmYtYTliZjEzNDk4NGYzIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDozOTk5L2F1dGgiLCJzZXNzaW9uSGFuZGxlIjoiNTAyMWQ2MTQtYzFmNi00ZTZkLWI1NjktZGQxN2Q0N2EyOWI0IiwicGFyZW50UmVmcmVzaFRva2VuSGFzaDEiOm51bGwsInJlZnJlc2hUb2tlbkhhc2gxIjoiNTZiMjcxZDcxNGRlMzg3M2UwMmIyZjAyYTJiZDcyYWJjZDIyZDM0NGZlZjE2YTJkMWJjYmM1NGU2YWUxN2M3OCJ9"#.as_bytes();
let mut json = r"eyJpYXQiOjEuNjkyMTkwMTI1RTksImV4cCI6MS42OTIxOTM3MjVFOSwiYW50aUNzcmZUb2tlbiI6bnVsbCwic3ViIjoiYTM5ZmZjNWUtNjc5ZC00YjAzLWI5YmYtYTliZjEzNDk4NGYzIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDozOTk5L2F1dGgiLCJzZXNzaW9uSGFuZGxlIjoiNTAyMWQ2MTQtYzFmNi00ZTZkLWI1NjktZGQxN2Q0N2EyOWI0IiwicGFyZW50UmVmcmVzaFRva2VuSGFzaDEiOm51bGwsInJlZnJlc2hUb2tlbkhhc2gxIjoiNTZiMjcxZDcxNGRlMzg3M2UwMmIyZjAyYTJiZDcyYWJjZDIyZDM0NGZlZjE2YTJkMWJjYmM1NGU2YWUxN2M3OCJ9".as_bytes();

let r = base64::read::DecoderReader::new(&mut json, url_safe_trailing_bits());
let r = base64::read::DecoderReader::new(&mut json, &URL_SAFE_TRAILING_BITS);

let claims: Claims<Value> = serde_json::from_reader(r).unwrap();
assert_eq!(claims.iat, Some(Duration::from_secs(1692190125)));
assert_eq!(claims.exp, Some(Duration::from_secs(1692193725)));
assert_eq!(claims.iat, Some(Duration::from_secs(1_692_190_125)));
assert_eq!(claims.exp, Some(Duration::from_secs(1_692_193_725)));
}
}
20 changes: 10 additions & 10 deletions src/rsa.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
use crate::{
jwk::Jwk, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey, VerificationKey,
URL_SAFE_TRAILING_BITS,
};
use base64::Engine as _;
/// RSASSA-PKCS1-v1_5 using SHA-256.
use openssl::{
bn::BigNum,
Expand All @@ -8,11 +13,6 @@ use openssl::{
};
use smallvec::SmallVec;

use crate::{
jwk::Jwk, url_safe_trailing_bits, Error, PrivateKeyToJwk, PublicKeyToJwk, Result, SigningKey,
VerificationKey,
};

/// RSA signature algorithms.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -161,7 +161,7 @@ impl PrivateKeyToJwk for RsaPrivateKey {
let dq = rsa.dmq1().map(|dq| dq.to_vec());
let qi = rsa.iqmp().map(|qi| qi.to_vec());
fn encode(x: &[u8]) -> String {
base64::encode_config(x, url_safe_trailing_bits())
URL_SAFE_TRAILING_BITS.encode(x)
}
Ok(Jwk {
kty: "RSA".into(),
Expand Down Expand Up @@ -194,8 +194,8 @@ impl PublicKeyToJwk for RsaPrivateKey {
Some(self.algorithm.name().into())
},
use_: Some("sig".into()),
n: Some(base64::encode_config(self.n()?, url_safe_trailing_bits())),
e: Some(base64::encode_config(self.e()?, url_safe_trailing_bits())),
n: Some(URL_SAFE_TRAILING_BITS.encode(self.n()?)),
e: Some(URL_SAFE_TRAILING_BITS.encode(self.e()?)),
..Jwk::default()
})
}
Expand Down Expand Up @@ -265,8 +265,8 @@ impl PublicKeyToJwk for RsaPublicKey {
kty: "RSA".into(),
alg: self.algorithm.map(|alg| alg.name().to_string()),
use_: Some("sig".into()),
n: Some(base64::encode_config(self.n()?, url_safe_trailing_bits())),
e: Some(base64::encode_config(self.e()?, url_safe_trailing_bits())),
n: Some(URL_SAFE_TRAILING_BITS.encode(self.n()?)),
e: Some(URL_SAFE_TRAILING_BITS.encode(self.e()?)),
..Jwk::default()
})
}
Expand Down

0 comments on commit 0191ee5

Please sign in to comment.