diff --git a/Cargo.toml b/Cargo.toml index 33a0328..737c4af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ categories = [ [dependencies] base64ct = { version = "1.6.0", features = ["std"] } +bytes = { version = "1.5" } chrono = { version = "0.4.24", features = ["serde"] } elliptic-curve = { version = "0.13.4", features = [ "pkcs8", diff --git a/README.md b/README.md index 2ea01e1..70012a4 100644 --- a/README.md +++ b/README.md @@ -72,17 +72,27 @@ stateDiagram-v2 To create a simple JWT, you'll need to provide an encryption key. This example uses the RSA encrption key defined in Appendix A.2 of [RFC 7515][JWS], don't re use it! -This example is reproduced from [`examples/rfc7515a2.rs`][/examples/rfc7515a2.rs] in the repository, -and can be run with `cargo run --example rfc7515a2`. +This example is reproduced from [`examples/rfc7515a2.rs`](./examples/rfc7515a2.rs) in the repository, +and can be run with `cargo run --example rfc7515-a2`. ```rust +use jaws::Compact; +use std::ops::Deref; + // JAWS provides JWT format for printing JWTs in a style similar to the example above, // which is directly inspired by the way the ACME standard shows JWTs. use jaws::JWTFormat; // JAWS provides a single token type which is generic over the state of the token. +// The states are defined in the `state` module, and are used to track the +// signing and verification status. use jaws::Token; +use jaws::algorithms::rsa::RsaPkcs1v15Verify; +// The unverified token state, used like `Token<.., Unverified<..>, ..>`. +// It is generic over the type of the custom header parameters. +use jaws::token::Unverified; + // JAWS provides type-safe support for JWT claims. use jaws::{Claims, RegisteredClaims}; @@ -94,10 +104,8 @@ use rsa::pkcs8::DecodePrivateKey; // function, so we get it here from the `sha2` crate in the RustCrypto suite. use sha2::Sha256; -// JAWS provides thin algorithm wrappers for algorithms which accept -// parameters beyond just the encryption or singing key. For example, the `RS256` -// algorithm accepts a hash function, but is otherwise identical to the other -// `RS*` hash functions. +// This is an alias for the RSA PKCS#1 v1.5 signing algorithm, which is +// implemented in the rsa crate as `rsa::pkcs1v15::SigningKey`. use jaws::algorithms::rsa::RsaPkcs1v15; // Using serde_json allows us to quickly construct a serializable payload, @@ -143,9 +151,25 @@ fn main() -> Result<(), Box> { // we provide the `typ` header, which is optional in the JWT spec. *token.header_mut().r#type() = Some("JWT".to_string()); + // We can also ask that some fields be derived from the signing key, for example, + // this will derive the JWK field in the header from the signing key. + token.header_mut().key().derived(); + + println!("Initial JWT"); + + // Initially the JWT has no defined signature: + println!("JWT:"); + println!("{}", token.formatted()); + // Sign the token with the algorithm, and print the result. let signed = token.sign(&alg).unwrap(); + println!("Signed JWT"); + + println!("JWT:"); + println!("{}", signed.formatted()); + println!("Token: {}", signed.rendered().unwrap()); + // We can't modify the token after signing it (that would change the signature) // but we can access fields and read from them: println!( @@ -154,12 +178,64 @@ fn main() -> Result<(), Box> { signed.header().algorithm(), ); - println!("Token: {}", signed.rendered().unwrap()); + // We can also verify tokens. + let token: Token, Unverified<()>, Compact> = + signed.rendered().unwrap().parse().unwrap(); + + println!("Parsed JWT"); + + // Unverified tokens can be printed for debugging, but there is deliberately + // no access to the payload, only to the header fields. println!("JWT:"); - println!("{}", signed.formatted()); + println!("{}", token.formatted()); + + // We can use the JWK to verify that the token is signed with the correct key. + let hdr = token.header(); + let jwk = hdr.key().unwrap(); + let key = rsa_jwk_reader::rsa_pub(&serde_json::to_value(jwk).unwrap()); + + assert_eq!(&key, alg.as_ref().deref()); + + let alg: RsaPkcs1v15Verify = RsaPkcs1v15Verify::new_with_prefix(key); + + // We can't access the claims until we verify the token. + let verified = token.verify(&alg).unwrap(); + + println!("Verified JWT"); + println!("JWT:"); + println!("{}", verified.formatted()); + println!( + "Payload: \n{}", + serde_json::to_string_pretty(&verified.payload()).unwrap() + ); Ok(()) } + +mod rsa_jwk_reader { + use base64ct::Encoding; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + fn to_biguint(v: &serde_json::Value) -> Option { + let val = strip_whitespace(v.as_str()?); + Some(rsa::BigUint::from_bytes_be( + base64ct::Base64UrlUnpadded::decode_vec(&val) + .ok()? + .as_slice(), + )) + } + + pub(crate) fn rsa_pub(key: &serde_json::Value) -> rsa::RsaPublicKey { + let n = to_biguint(&key["n"]).expect("decode n"); + let e = to_biguint(&key["e"]).expect("decode e"); + + rsa::RsaPublicKey::new(n, e).expect("valid key parameters") + } +} + ``` ## Philosophy diff --git a/examples/rfc7515a2.rs b/examples/rfc7515a2.rs index ffc1b3c..716e28d 100644 --- a/examples/rfc7515a2.rs +++ b/examples/rfc7515a2.rs @@ -1,12 +1,20 @@ use jaws::Compact; +use std::ops::Deref; + // JAWS provides JWT format for printing JWTs in a style similar to the example above, // which is directly inspired by the way the ACME standard shows JWTs. use jaws::JWTFormat; -// JAWS provides strongly typed support for tokens, so we can only build an UnsignedToken, -// which we can sign to create a SignedToken or a plain Token. +// JAWS provides a single token type which is generic over the state of the token. +// The states are defined in the `state` module, and are used to track the +// signing and verification status. use jaws::Token; +use jaws::algorithms::rsa::RsaPkcs1v15Verify; +// The unverified token state, used like `Token<.., Unverified<..>, ..>`. +// It is generic over the type of the custom header parameters. +use jaws::token::Unverified; + // JAWS provides type-safe support for JWT claims. use jaws::{Claims, RegisteredClaims}; @@ -18,10 +26,8 @@ use rsa::pkcs8::DecodePrivateKey; // function, so we get it here from the `sha2` crate in the RustCrypto suite. use sha2::Sha256; -// JAWS provides thin algorithm wrappers for algorithms which accept -// parameters beyond just the encryption or singing key. For example, the `RS256` -// algorithm accepts a hash function, but is otherwise identical to the other -// `RS*` hash functions. +// This is an alias for the RSA PKCS#1 v1.5 signing algorithm, which is +// implemented in the rsa crate as `rsa::pkcs1v15::SigningKey`. use jaws::algorithms::rsa::RsaPkcs1v15; // Using serde_json allows us to quickly construct a serializable payload, @@ -62,14 +68,30 @@ fn main() -> Result<(), Box> { // The unit type can be used here because it implements [serde::Serialize], // but a custom type could be passed if we wanted to have custom header // fields. - let mut token = Token::new((), claims, Compact); + let mut token = Token::compact((), claims); // We can modify the headers freely before signing the JWT. In this case, // we provide the `typ` header, which is optional in the JWT spec. *token.header_mut().r#type() = Some("JWT".to_string()); + // We can also ask that some fields be derived from the signing key, for example, + // this will derive the JWK field in the header from the signing key. + token.header_mut().key().derived(); + + println!("Initial JWT"); + + // Initially the JWT has no defined signature: + println!("JWT:"); + println!("{}", token.formatted()); + // Sign the token with the algorithm, and print the result. let signed = token.sign(&alg).unwrap(); + println!("Signed JWT"); + + println!("JWT:"); + println!("{}", signed.formatted()); + println!("Token: {}", signed.rendered().unwrap()); + // We can't modify the token after signing it (that would change the signature) // but we can access fields and read from them: println!( @@ -78,9 +100,60 @@ fn main() -> Result<(), Box> { signed.header().algorithm(), ); - println!("Token: {}", signed.rendered().unwrap()); + // We can also verify tokens. + let token: Token, Unverified<()>, Compact> = + signed.rendered().unwrap().parse().unwrap(); + + println!("Parsed JWT"); + + // Unverified tokens can be printed for debugging, but there is deliberately + // no access to the payload, only to the header fields. println!("JWT:"); - println!("{}", signed.formatted()); + println!("{}", token.formatted()); + + // We can use the JWK to verify that the token is signed with the correct key. + let hdr = token.header(); + let jwk = hdr.key().unwrap(); + let key = rsa_jwk_reader::rsa_pub(&serde_json::to_value(jwk).unwrap()); + + assert_eq!(&key, alg.as_ref().deref()); + + let alg: RsaPkcs1v15Verify = RsaPkcs1v15Verify::new_with_prefix(key); + + // We can't access the claims until we verify the token. + let verified = token.verify(&alg).unwrap(); + + println!("Verified JWT"); + println!("JWT:"); + println!("{}", verified.formatted()); + println!( + "Payload: \n{}", + serde_json::to_string_pretty(&verified.payload()).unwrap() + ); Ok(()) } + +mod rsa_jwk_reader { + use base64ct::Encoding; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + fn to_biguint(v: &serde_json::Value) -> Option { + let val = strip_whitespace(v.as_str()?); + Some(rsa::BigUint::from_bytes_be( + base64ct::Base64UrlUnpadded::decode_vec(&val) + .ok()? + .as_slice(), + )) + } + + pub(crate) fn rsa_pub(key: &serde_json::Value) -> rsa::RsaPublicKey { + let n = to_biguint(&key["n"]).expect("decode n"); + let e = to_biguint(&key["e"]).expect("decode e"); + + rsa::RsaPublicKey::new(n, e).expect("valid key parameters") + } +} diff --git a/src/algorithms/ecdsa.rs b/src/algorithms/ecdsa.rs index 7d76f3c..3328553 100644 --- a/src/algorithms/ecdsa.rs +++ b/src/algorithms/ecdsa.rs @@ -74,7 +74,9 @@ use ::ecdsa::{ PrimeCurve, SignatureSize, }; use base64ct::Encoding; +use bytes::BytesMut; use digest::generic_array::ArrayLength; +use ecdsa::{hazmat::VerifyPrimitive, Signature, VerifyingKey}; use elliptic_curve::{ ops::Invert, sec1::{Coordinates, FromEncodedPoint, ModulusSize, ToEncodedPoint}, @@ -92,6 +94,7 @@ pub use p384::NistP384; #[cfg(feature = "p521")] pub use p521::NistP521; +use signature::Verifier; impl crate::key::JWKeyType for PublicKey where @@ -109,7 +112,7 @@ where FieldBytesSize: ModulusSize, { fn parameters(&self) -> Vec<(String, serde_json::Value)> { - let mut params = Vec::with_capacity(2); + let mut params = Vec::with_capacity(3); params.push(( "crv".to_owned(), @@ -144,6 +147,31 @@ where self.public_key().parameters() } } + +impl crate::key::SerializeJWK for ecdsa::VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, +{ + fn parameters(&self) -> Vec<(String, serde_json::Value)> { + PublicKey::::from(self).parameters() + } +} + +impl crate::key::JWKeyType for ecdsa::VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, +{ + const KEY_TYPE: &'static str = "EC"; +} + impl crate::key::SerializeJWK for ecdsa::SigningKey where C: PrimeCurve + CurveArithmetic + JwkParameters, @@ -153,7 +181,7 @@ where FieldBytesSize: ModulusSize, { fn parameters(&self) -> Vec<(String, serde_json::Value)> { - todo!() + self.verifying_key().parameters() } } @@ -223,6 +251,60 @@ where } } +impl super::VerifyAlgorithm for VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters + ecdsa::hazmat::DigestPrimitive, + ::AffinePoint: VerifyPrimitive, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + MaxSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, + VerifyingKey: super::Algorithm>, + as Add>::Output: Add + ArrayLength, +{ + type Error = ecdsa::Error; + type Key = VerifyingKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + let signature = ecdsa::Signature::try_from(signature)?; + >>::verify(self, message.as_ref(), &signature)?; + Ok(signature.into()) + } + + fn key(&self) -> &Self::Key { + self + } +} + +#[cfg(feature = "p256")] +impl super::Algorithm for VerifyingKey { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::ES256; + type Signature = ecdsa::SignatureBytes; +} + +#[cfg(feature = "p384")] +impl super::Algorithm for VerifyingKey { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::ES384; + type Signature = ecdsa::SignatureBytes; +} + +#[cfg(feature = "p521")] +impl super::Algorithm for VerifyingKey { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::ES512; + type Signature = ecdsa::SignatureBytes; +} + #[cfg(all(test, feature = "p256"))] mod test { diff --git a/src/algorithms/hmac.rs b/src/algorithms/hmac.rs index fd9b616..51259d6 100644 --- a/src/algorithms/hmac.rs +++ b/src/algorithms/hmac.rs @@ -2,9 +2,10 @@ //! //! Based on the [hmac](https://crates.io/crates/hmac) crate. -use std::convert::Infallible; +use std::{convert::Infallible, marker::PhantomData}; use base64ct::Encoding; +use bytes::BytesMut; use digest::Mac; use hmac::SimpleHmac; @@ -81,7 +82,7 @@ where D: digest::Digest + digest::core_api::BlockSizeUser, { key: HmacKey, - digest: hmac::SimpleHmac, + _digest: PhantomData, } impl Hmac @@ -93,8 +94,10 @@ where /// /// Signing keys are arbitrary bytes. pub fn new(key: HmacKey) -> Self { - let digest = SimpleHmac::new_from_slice(key.as_ref()).expect("Valid key"); - Self { key, digest } + Self { + key, + _digest: PhantomData, + } } /// Reference to the HMAC key. @@ -133,8 +136,8 @@ where fn sign(&self, header: &str, payload: &str) -> Result { // Create a new, one-shot digest for this signature. - let mut digest = self.digest.clone(); - digest.reset(); + let mut digest: SimpleHmac = + SimpleHmac::new_from_slice(self.key.as_ref()).expect("Valid key"); let message = format!("{}.{}", header, payload); digest.update(message.as_bytes()); Ok(digest.finalize().into_bytes()) @@ -145,6 +148,43 @@ where } } +impl super::VerifyAlgorithm for Hmac +where + D: digest::Digest + + digest::Reset + + digest::core_api::BlockSizeUser + + digest::FixedOutput + + digest::core_api::CoreProxy + + Clone, + Hmac: super::Algorithm>>, +{ + type Error = digest::MacError; + type Key = HmacKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + // Create a new, one-shot digest for this signature. + let mut digest: SimpleHmac = + SimpleHmac::new_from_slice(self.key.as_ref()).expect("Valid key"); + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + digest.update(message.as_ref()); + digest.clone().verify(signature.into())?; + Ok(digest.finalize().into_bytes()) + } + + fn key(&self) -> &Self::Key { + &self.key + } +} + #[cfg(test)] mod test { use crate::algorithms::SigningAlgorithm; diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 8b017e7..85f5499 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -69,6 +69,21 @@ pub enum AlgorithmIdentifier { None, } +impl AlgorithmIdentifier { + /// Return whether this algorithm is available for signing. + pub fn available(&self) -> bool { + match self { + Self::None => true, + + Self::HS256 | Self::HS384 | Self::HS512 => cfg!(feature = "hmac"), + Self::RS256 | Self::RS384 | Self::RS512 => cfg!(feature = "rsa"), + Self::ES256 | Self::ES384 | Self::ES512 => cfg!(feature = "ecdsa"), + Self::EdDSA => cfg!(feature = "ed25519"), + Self::PS256 | Self::PS384 | Self::PS512 => cfg!(feature = "rsa"), + } + } +} + /// A trait to associate an alogritm identifier with an algorithm. /// /// Algorithm identifiers are used in JWS and JWE to indicate how a token is signed or encrypted. @@ -118,8 +133,8 @@ pub trait VerifyAlgorithm: Algorithm { /// and payload. fn verify( &self, - header: &str, - payload: &str, + header: &[u8], + payload: &[u8], signature: &[u8], ) -> Result; @@ -133,9 +148,9 @@ pub trait VerifyAlgorithm: Algorithm { /// on the heap. It is used to store the signature of a JWT before it is verified, /// or if a signature has a variable length. #[derive(Debug, Clone, PartialEq, Eq, Hash, zeroize::Zeroize, zeroize::ZeroizeOnDrop)] -pub struct Signature(Vec); +pub struct SignatureBytes(Vec); -impl Signature { +impl SignatureBytes { /// Add to this signature from a byte slice. pub fn extend_from_slice(&mut self, other: &[u8]) { self.0.extend_from_slice(other); @@ -143,24 +158,24 @@ impl Signature { /// Create a new signature with the given capacity. pub fn with_capacity(capacity: usize) -> Self { - Signature(Vec::with_capacity(capacity)) + SignatureBytes(Vec::with_capacity(capacity)) } } -impl AsRef<[u8]> for Signature { +impl AsRef<[u8]> for SignatureBytes { fn as_ref(&self) -> &[u8] { &self.0 } } -impl From<&[u8]> for Signature { +impl From<&[u8]> for SignatureBytes { fn from(bytes: &[u8]) -> Self { - Signature(bytes.to_vec()) + SignatureBytes(bytes.to_vec()) } } -impl From> for Signature { +impl From> for SignatureBytes { fn from(bytes: Vec) -> Self { - Signature(bytes) + SignatureBytes(bytes) } } diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index 78e22d2..55d1ed9 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -16,7 +16,7 @@ //! This algorithm is used to sign and verify JSON Web Tokens using the RSASSA-PSS. use base64ct::{Base64UrlUnpadded, Encoding}; -use rsa::pkcs1v15::SigningKey; +use bytes::BytesMut; use rsa::rand_core::OsRng; use rsa::signature::RandomizedSigner; use rsa::PublicKeyParts; @@ -49,7 +49,8 @@ impl crate::key::SerializeJWK for rsa::RsaPrivateKey { } /// Alogrithm wrapper for the Digital Signature with RSASSA-PKCS1-v1_5 algorithm. -pub type RsaPkcs1v15 = SigningKey; +pub type RsaPkcs1v15 = rsa::pkcs1v15::SigningKey; +pub type RsaPkcs1v15Verify = rsa::pkcs1v15::VerifyingKey; impl super::SigningAlgorithm for RsaPkcs1v15 where @@ -69,6 +70,38 @@ where } } +impl super::VerifyAlgorithm for RsaPkcs1v15Verify +where + D: digest::Digest, + RsaPkcs1v15Verify: super::Algorithm + Clone, +{ + type Error = signature::Error; + + type Key = rsa::RsaPublicKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + use rsa::signature::Verifier; + let signature = rsa::pkcs1v15::Signature::try_from(signature).unwrap(); + + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + >::verify(self, message.as_ref(), &signature)?; + Ok(signature) + } + + fn key(&self) -> &Self::Key { + self.as_ref() + } +} + impl super::Algorithm for RsaPkcs1v15 { const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS256; type Signature = rsa::pkcs1v15::Signature; @@ -84,6 +117,21 @@ impl super::Algorithm for RsaPkcs1v15 { type Signature = rsa::pkcs1v15::Signature; } +impl super::Algorithm for RsaPkcs1v15Verify { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS256; + type Signature = rsa::pkcs1v15::Signature; +} + +impl super::Algorithm for RsaPkcs1v15Verify { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS384; + type Signature = rsa::pkcs1v15::Signature; +} + +impl super::Algorithm for RsaPkcs1v15Verify { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS512; + type Signature = rsa::pkcs1v15::Signature; +} + /// Algorithm wrapper for RSA-PSS signatures, using [rsa::pss::BlindedSigningKey]. pub type RsaPSSKey = rsa::pss::BlindedSigningKey; diff --git a/src/base64data.rs b/src/base64data.rs index d747358..4210ab3 100644 --- a/src/base64data.rs +++ b/src/base64data.rs @@ -7,11 +7,27 @@ use std::fmt::Write; use std::marker::PhantomData; use base64ct::Encoding; -use serde::{de, ser, Serialize}; +use bytes::Bytes; +use serde::{ + de::{self, DeserializeOwned}, + ser, Serialize, +}; #[cfg(feature = "fmt")] use super::fmt::{self, IndentWriter}; +#[derive(Debug, thiserror::Error)] +pub enum DecodeError { + #[error(transparent)] + Base64(#[from] base64ct::Error), + + #[error(transparent)] + Json(#[from] serde_json::Error), + + #[error("data is not valid: {0}")] + InvalidData(#[source] Box), +} + /// Wrapper type to indicate that the inner type should be serialized /// as bytes with a Base64 URL-safe encoding. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -26,6 +42,18 @@ where } } +impl Base64Data +where + T: TryFrom>, + T::Error: std::error::Error + Send + Sync + 'static, +{ + pub(crate) fn parse(value: &str) -> Result { + let data = base64ct::Base64UrlUnpadded::decode_vec(value)?; + let data = T::try_from(data).map_err(|err| DecodeError::InvalidData(err.into()))?; + Ok(Base64Data(data)) + } +} + impl From for Base64Data { fn from(value: T) -> Self { Base64Data(value) @@ -112,6 +140,16 @@ where #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Base64JSON(pub T); +impl Base64JSON { + pub fn new(value: T) -> Self { + Base64JSON(value) + } + + pub fn into_inner(self) -> T { + self.0 + } +} + impl Base64JSON where T: Serialize, @@ -120,6 +158,32 @@ where let inner = serde_json::to_vec(&self.0)?; Ok(base64ct::Base64UrlUnpadded::encode_string(&inner)) } + + pub(crate) fn serialized_bytes(&self) -> Result { + self.serialized_value().map(Bytes::from) + } +} + +pub(crate) struct ParsedBase64JSON { + pub(crate) data: T, + pub(crate) bytes: Bytes, +} + +impl Base64JSON +where + T: DeserializeOwned, +{ + pub(crate) fn parse(raw: &str) -> Result, DecodeError> + where + T: de::DeserializeOwned, + { + let data = base64ct::Base64UrlUnpadded::decode_vec(raw)?; + let value = serde_json::from_slice(&data)?; + Ok(ParsedBase64JSON { + data: value, + bytes: Bytes::from(raw.to_owned()), + }) + } } impl AsRef for Base64JSON { diff --git a/src/claims.rs b/src/claims.rs index d7e771f..64e5bce 100644 --- a/src/claims.rs +++ b/src/claims.rs @@ -22,13 +22,13 @@ use crate::fmt; /// request. /// /// Fields which are `None` are left out of the regsitered header. -#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct RegisteredClaims { /// Claim issuer identifies the principal that issued the /// JWT. The processing of this claim is generally application specific. /// The "iss" value is a case-sensitive string containing a StringOrURI /// value. Use of this claim is OPTIONAL. - #[serde(rename = "iss", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "iss", skip_serializing_if = "Option::is_none")] pub issuer: Option, /// Claim subject identifies the principal that is the @@ -38,7 +38,7 @@ pub struct RegisteredClaims, /// The "aud" (audience) claim identifies the recipients that the JWT is @@ -51,7 +51,7 @@ pub struct RegisteredClaims, /// The "exp" (expiration time) claim identifies the expiration time on or @@ -62,6 +62,7 @@ pub struct RegisteredClaims>, /// The "jti" (JWT ID) claim provides a unique identifier for the JWT. The identifier value MUST be assigned in a manner that ensures that there is a negligible probability that the same value will be accidentally assigned to a different data object; if the application uses multiple issuers, collisions MUST be prevented among values produced by different issuers as well. The "jti" claim can be used to prevent the JWT from being replayed. The "jti" value is a case- sensitive string. Use of this claim is OPTIONAL. - #[serde(rename = "jti", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "jti", skip_serializing_if = "Option::is_none")] pub token_id: Option, } +impl Default for RegisteredClaims { + fn default() -> Self { + Self { + issuer: None, + subject: None, + audience: None, + expiration: None, + not_before: None, + issued_at: None, + token_id: None, + } + } +} + #[cfg(feature = "fmt")] impl fmt::JWTFormat for RegisteredClaims where @@ -117,10 +134,10 @@ where /// They consist of "registered" header values, specified in RFC 7519, /// and a set of custom claims, which can be any arbitrary key-value /// pairs seializable as JSON. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)] pub struct Claims { /// Registered claims, which are enumerated specifically. See [RegisteredClaims]. - #[serde(flatten)] + #[serde(default, flatten)] pub registered: RegisteredClaims, /// Custom claims, which are any arbitrary JSON objects. Custom claims must implement @@ -130,6 +147,18 @@ pub struct Claims { pub claims: C, } +impl Default for Claims +where + C: Default, +{ + fn default() -> Self { + Self { + registered: Default::default(), + claims: Default::default(), + } + } +} + impl Claims { /// Create a new set of claims. Claims can also be created by constructing the /// struct literal. diff --git a/src/jose/derive.rs b/src/jose/derive.rs index 12ff501..fa1c7bc 100644 --- a/src/jose/derive.rs +++ b/src/jose/derive.rs @@ -106,6 +106,13 @@ where KeyDerivation::Explicit(value) => DerivedKeyValue::Explicit(value), } } + + pub(super) fn explicit(value: Option) -> Self { + match value { + Some(value) => DerivedKeyValue::Explicit(value), + None => DerivedKeyValue::Omit, + } + } } impl ser::Serialize for DerivedKeyValue diff --git a/src/jose/mod.rs b/src/jose/mod.rs index 74b88d8..78ccb3a 100644 --- a/src/jose/mod.rs +++ b/src/jose/mod.rs @@ -16,7 +16,6 @@ use sha1::Sha1; use sha2::Sha256; use url::Url; -#[cfg(feature = "fmt")] use crate::base64data::Base64JSON; use crate::{algorithms::AlgorithmIdentifier, key::SerializeJWK}; @@ -54,6 +53,9 @@ pub enum HeaderError { #[error("invalid header type: {0}")] InvalidType(String), + #[error("invalid custom headers: {0} JSON serialized form must be an object or null")] + InvalidCustomHeaders(&'static str), + #[error("unable to serialize header value: {0}")] Serde(#[from] serde_json::Error), } @@ -69,31 +71,31 @@ pub enum HeaderError { #[derive(Debug, Clone, Serialize, Default, PartialEq, Eq, Deserialize)] struct RegisteredHeaderFields { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/jwk_set_url.md"))] - #[serde(rename = "jku", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "jku", skip_serializing_if = "Option::is_none")] jwk_set_url: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/type.md"))] - #[serde(rename = "typ", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "typ", skip_serializing_if = "Option::is_none")] r#type: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/key_id.md"))] - #[serde(rename = "kid", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")] key_id: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/certificate_url.md"))] - #[serde(rename = "x5u", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "x5u", skip_serializing_if = "Option::is_none")] pub certificate_url: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/certificate_chain.md"))] - #[serde(rename = "x5c", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "x5c", skip_serializing_if = "Option::is_none")] certificate_chain: Option>, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/content_type.md"))] - #[serde(rename = "cty", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "cty", skip_serializing_if = "Option::is_none")] content_type: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/critical.md"))] - #[serde(rename = "crit", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "crit", skip_serializing_if = "Option::is_none")] critical: Option>, } @@ -108,7 +110,7 @@ const REGISTERED_HEADER_KEYS: [&str; 11] = [ #[non_exhaustive] pub struct Header { #[serde(flatten)] - state: State, + pub(crate) state: State, /// The set of registered header parameters from [JWS][] and [JWA][]. /// @@ -160,7 +162,7 @@ impl Header { } /// Construct the JOSE header from the builder and signing key. - pub(crate) fn sign(self, key: &A::Key) -> Header> + pub(crate) fn into_signed_header(self, key: &A::Key) -> Header> where A: crate::algorithms::SigningAlgorithm, A::Key: Clone, @@ -178,21 +180,6 @@ impl Header { custom: self.custom, } } - - /// Access the JWK setting for the header. - pub fn jwk(&mut self) -> &mut KeyDerivation { - &mut self.state.key - } - - /// Access the JWK thumbprint setting for the header. - pub fn thumbprint(&mut self) -> &mut KeyDerivation> { - &mut self.state.thumbprint - } - - /// Access the JWK thumbprint setting for the header. - pub fn thumbprint_sha256(&mut self) -> &mut KeyDerivation> { - &mut self.state.thumbprint_sha256 - } } impl Header> @@ -200,16 +187,29 @@ where Key: SerializeJWK, { /// JWK signing algorithm in use. - pub fn algorithm(&self) -> &AlgorithmIdentifier { + pub(crate) fn algorithm(&self) -> &AlgorithmIdentifier { &self.state.algorithm } /// Render a signed JWK header into its rendered /// form, where the derived fields have been built /// as necessary. - pub fn render(self) -> Header { + /// + /// This will cause the header to "forget" what key and algorithm + /// are used in the signature, rendering those values literally into + /// the header. + pub(crate) fn into_rendered_header(self) -> Header + where + H: Serialize, + SignedHeader: HeaderState, + { + let headers = Base64JSON(&self) + .serialized_bytes() + .expect("valid header value"); + let state = RenderedHeader { - algorithm: self.state.algorithm, + raw: headers, + algorithm: *self.algorithm(), key: self.state.key.build(), thumbprint: self.state.thumbprint.build(), thumbprint_sha256: self.state.thumbprint_sha256.build(), @@ -224,16 +224,41 @@ where } impl Header { - pub fn algorithm(&self) -> &AlgorithmIdentifier { + /// JWK signing algorithm in use. + pub(crate) fn algorithm(&self) -> &AlgorithmIdentifier { &self.state.algorithm } + /// Convert a rendered header into a signed header, where the algorithms + /// must match. This is used when verifying a token, but cannot check + /// the validity of any other header fields. + /// + /// # Panics + /// + /// If the key algorithm does not match the header's algorithm. #[allow(unused_variables)] - pub(crate) fn verify(self, key: &A::Key) -> Result>, A::Error> + pub(crate) fn into_signed_header(self, key: &A::Key) -> Header> where A: crate::algorithms::VerifyAlgorithm, { - todo!("verify"); + if *self.algorithm() != A::IDENTIFIER { + panic!( + "algorithm mismatch: expected header to have {:?}, got {:?}", + A::IDENTIFIER, + self.algorithm() + ); + } + + Header { + state: SignedHeader { + algorithm: *self.algorithm(), + key: DerivedKeyValue::explicit(self.state.key), + thumbprint: DerivedKeyValue::explicit(self.state.thumbprint), + thumbprint_sha256: DerivedKeyValue::explicit(self.state.thumbprint_sha256), + }, + registered: self.registered, + custom: self.custom, + } } } @@ -242,6 +267,17 @@ where H: Serialize, State: HeaderState, { + /// Construct a JOSE header value. + /// + /// JOSE headers are JSON objects, so this method will serialize the header + /// into a JSON object, with keys lexicographically ordered. + /// + /// # Panics + /// + /// If a registered header were to conflict with a header owned by the type's + /// [HeaderState] implementation, this method will panic. That should not happen. + /// + /// If a custom header is not a JSON object or Null, this method will panic. pub(crate) fn value(&self) -> Result { // Re-using the parameters map here is important, because it will // alphabetize our keys, resulting in a consistent key order in rendered @@ -259,7 +295,7 @@ where } } Value::Null => {} - _ => panic!("registered headers are objects"), + _ => unreachable!("registered headers are objects"), } match custom { @@ -274,7 +310,7 @@ where } } Value::Null => {} - _ => panic!("custom headers are objects"), + _ => return Err(HeaderError::InvalidCustomHeaders(std::any::type_name::())), }; let mut map = serde_json::Map::new(); @@ -376,7 +412,7 @@ where { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] pub fn algorithm(&self) -> &AlgorithmIdentifier { - &self.header.state.algorithm + self.header.algorithm() } #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/json_web_key.md"))] @@ -398,7 +434,7 @@ where impl<'h, H> HeaderAccess<'h, H, RenderedHeader> { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] pub fn algorithm(&self) -> &AlgorithmIdentifier { - &self.header.state.algorithm + self.header.algorithm() } #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/json_web_key.md"))] @@ -501,7 +537,7 @@ where { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] pub fn algorithm(&self) -> &AlgorithmIdentifier { - &self.header.state.algorithm + self.header.algorithm() } #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/json_web_key.md"))] @@ -541,7 +577,3 @@ impl<'h, H> HeaderAccessMut<'h, H, RenderedHeader> { &mut self.header.state.thumbprint_sha256 } } - -/// Errors returned when verifying a header. -#[derive(Debug, thiserror::Error)] -pub enum VerifyHeaderError {} diff --git a/src/jose/rendered.rs b/src/jose/rendered.rs index 9f177d9..63573dc 100644 --- a/src/jose/rendered.rs +++ b/src/jose/rendered.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use serde::{Deserialize, Serialize}; use sha1::Sha1; use sha2::Sha256; @@ -13,8 +14,12 @@ use super::HeaderState; /// /// This is different from [super::SignedHeader] in that it contains the actual data, /// and not thd derivation, so the fields may be in inconsistent states. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct RenderedHeader { + /// The raw bytes of the header, as it was signed. + #[serde(skip)] + pub(crate) raw: Bytes, + #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] #[serde(rename = "alg")] pub(super) algorithm: AlgorithmIdentifier, diff --git a/src/token/formats.rs b/src/token/formats.rs index eca34d5..7d3f141 100644 --- a/src/token/formats.rs +++ b/src/token/formats.rs @@ -1,11 +1,15 @@ use std::fmt::Write; +use bytes::Bytes; +use serde::de::DeserializeOwned; use serde::{ser, Deserialize, Serialize}; -use super::{HasSignature, MaybeSigned}; +use super::{HasSignature, MaybeSigned, Unverified}; use super::{Payload, Token}; -use crate::base64data::{Base64Data, Base64JSON}; -use crate::jose::HeaderState; +use crate::algorithms::SignatureBytes; +use crate::base64data::{Base64Data, Base64JSON, DecodeError}; +use crate::jose::{HeaderState, RenderedHeader}; +use crate::Header; /// A token format that serializes the token as a compact string. #[derive(Debug, Default, Clone, Serialize, Deserialize)] @@ -66,6 +70,36 @@ pub enum TokenFormattingError { IO(#[from] std::fmt::Error), } +/// Error returned when a token cannot be parsed from a string. +/// +/// This error can occur when deserializing the header or payload +#[derive(Debug, thiserror::Error)] +pub enum TokenParseError { + /// Unable to find the header in the raw data. + #[error("missing header")] + MissingHeader, + + /// Unable to find the payload in the raw data. + #[error("missing payload")] + MissingPayload, + + /// Unable to find the signature in the raw data. + #[error("missing signature")] + MissingSignature, + + #[error(transparent)] + Utf8(#[from] std::str::Utf8Error), + + #[error(transparent)] + Base64(#[from] DecodeError), + + #[error(transparent)] + Json(#[from] serde_json::Error), + + #[error("unexpected JSON value for {0}: {1}")] + UnexpectedJSONValue(&'static str, serde_json::Value), +} + /// Trait for token formats, defining how they are serialized. pub trait TokenFormat { /// Render the token to the given writer. @@ -79,6 +113,13 @@ pub trait TokenFormat { S: HasSignature, ::Header: Serialize, ::HeaderState: HeaderState; + + /// Parse the token from a slice. + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized; } impl TokenFormat for Compact { @@ -99,6 +140,49 @@ impl TokenFormat for Compact { write!(writer, "{}.{}.{}", header, payload, signature)?; Ok(()) } + + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized, + { + let mut parts = data.splitn(3, |&b| b == b'.'); + let header = { + let b64_header = + std::str::from_utf8(parts.next().ok_or(TokenParseError::MissingHeader)?)?; + + let wrapped_header = Base64JSON::>::parse(b64_header)?; + let mut header = wrapped_header.data; + header.state.raw = wrapped_header.bytes; + header + }; + + let (payload, raw_payload) = { + let b64_payload: &str = + std::str::from_utf8(parts.next().ok_or(TokenParseError::MissingPayload)?)?; + let payload: Payload

= Payload::parse(b64_payload)?; + let raw_payload: Vec = b64_payload.as_bytes().into(); + (payload, raw_payload) + }; + + let signature = { + let signature = parts.next().ok_or(TokenParseError::MissingSignature)?; + let signature: Base64Data = + Base64Data::parse(std::str::from_utf8(signature)?)?; + signature + }; + + Ok(Token { + payload, + state: Unverified { + header, + signature, + payload: raw_payload.into(), + }, + fmt: Compact, + }) + } } #[derive(Debug, Serialize)] @@ -111,7 +195,7 @@ struct FlatToken<'t, P, U> { impl TokenFormat for FlatUnprotected where - U: Serialize, + U: Serialize + DeserializeOwned, { fn render( &self, @@ -139,6 +223,34 @@ where Ok(()) } + + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized, + { + let value: serde_json::Value = serde_json::from_slice(&data)?; + let serde_json::Value::Object(mut object): serde_json::Value = value else { + return Err(TokenParseError::UnexpectedJSONValue("token", value)); + }; + + let Token { payload, state, .. } = parse_flat_common_values(&mut object)?; + + let unprotected = { + let unprotected = object + .remove("unprotected") + .ok_or(TokenParseError::MissingHeader)?; + let unprotected: U = serde_json::from_value(unprotected)?; + unprotected + }; + + Ok(Token { + payload, + state, + fmt: FlatUnprotected { unprotected }, + }) + } } impl Serialize for Token> @@ -146,7 +258,7 @@ where S: HasSignature, ::Header: Serialize, ::HeaderState: HeaderState, - U: Serialize, + U: Serialize + DeserializeOwned, P: Serialize, { fn serialize(&self, serializer: Ser) -> Result @@ -204,6 +316,74 @@ impl TokenFormat for Flat { Ok(()) } + + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized, + { + let value: serde_json::Value = serde_json::from_slice(&data)?; + let serde_json::Value::Object(mut object): serde_json::Value = value else { + return Err(TokenParseError::UnexpectedJSONValue("token", value)); + }; + parse_flat_common_values(&mut object) + } +} + +fn parse_flat_common_values( + object: &mut serde_json::Map, +) -> Result, Flat>, TokenParseError> +where + P: DeserializeOwned, + H: DeserializeOwned, +{ + let header = { + let protected = object + .get("protected") + .ok_or(TokenParseError::MissingHeader)?; + let protected = + Base64JSON::>::parse(protected.as_str().ok_or_else( + || TokenParseError::UnexpectedJSONValue("header", protected.clone()), + )?)?; + let mut header = protected.data; + header.state.raw = protected.bytes; + header + }; + + let (payload, raw_payload) = { + let value_payload = object + .remove("payload") + .ok_or(TokenParseError::MissingPayload)?; + let b64_payload = value_payload.as_str().ok_or_else(|| { + TokenParseError::UnexpectedJSONValue("payload", value_payload.clone()) + })?; + + let payload: Payload

= Payload::parse(b64_payload)?; + let raw_payload: Vec = b64_payload.as_bytes().into(); + (payload, raw_payload) + }; + + let signature = { + let signature = object + .remove("signature") + .ok_or(TokenParseError::MissingSignature)?; + let signature: Base64Data = + Base64Data::parse(signature.as_str().ok_or_else(|| { + TokenParseError::UnexpectedJSONValue("signature", signature.clone()) + })?)?; + signature + }; + + Ok(Token { + payload, + state: Unverified { + header, + signature, + payload: raw_payload.into(), + }, + fmt: Flat, + }) } impl Serialize for Token diff --git a/src/token/mod.rs b/src/token/mod.rs index 8b004ac..b3c993c 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -8,17 +8,22 @@ //! //! [RFC7519]: https://tools.ietf.org/html/rfc7519 -use std::fmt::Write; use std::marker::PhantomData; +use std::{fmt::Write, str::FromStr}; use base64ct::Encoding; -use serde::{de, ser, Deserialize, Serialize}; +use bytes::Bytes; +use serde::{ + de::{self, DeserializeOwned}, + ser, Deserialize, Serialize, +}; +use crate::algorithms::VerifyAlgorithm; #[cfg(feature = "fmt")] use crate::fmt; use crate::{ algorithms::{AlgorithmIdentifier, SigningAlgorithm}, - base64data::{Base64Data, Base64JSON}, + base64data::{Base64Data, Base64JSON, DecodeError}, jose::{HeaderAccess, HeaderAccessMut, HeaderState}, Header, }; @@ -26,6 +31,7 @@ use crate::{ mod formats; mod state; +use self::formats::TokenParseError; pub use self::formats::{Compact, Flat, FlatUnprotected, TokenFormat, TokenFormattingError}; pub use self::state::{HasSignature, MaybeSigned, Signed, Unsigned, Unverified, Verified}; @@ -67,6 +73,27 @@ where Payload::Empty => Ok("".to_owned()), } } + + fn serialized_bytes(&self) -> Result { + match self { + Payload::Json(data) => data.serialized_bytes(), + Payload::Empty => Ok(Bytes::new()), + } + } +} + +impl

Payload

+where + P: DeserializeOwned, +{ + fn parse(value: &str) -> Result { + if value.is_empty() { + return Ok(Payload::Empty); + } + + let parsed = Base64JSON::

::parse(value)?; + Ok(Payload::Json(parsed.data.into())) + } } impl

From

for Payload

{ @@ -149,10 +176,12 @@ where /// let token = Token::compact((), ()); /// ``` /// -/// This token will have no payload, and no custom headers, but it is still usable: +/// This token will have no payload, and no custom headers. #[cfg_attr( feature = "fmt", doc = r#" +To view a debug representation of the token, use the [`fmt::JWTFormat`] trait: + ``` # use jaws::token::Token; # let token = Token::compact((), ()); @@ -298,6 +327,18 @@ where } } +impl Token, Fmt> +where + Fmt: TokenFormat, +{ + pub fn payload(&self) -> Option<&P> { + match &self.payload { + Payload::Json(data) => Some(data.as_ref()), + Payload::Empty => None, + } + } +} + impl Token, Fmt> where H: Serialize, @@ -319,7 +360,7 @@ where A::Key: Clone, // A::Signature: Serialize, { - let header = self.state.header.sign::(algorithm.key()); + let header = self.state.header.into_signed_header::(algorithm.key()); let headers = Base64JSON(&header).serialized_value()?; let payload = self.payload.serialized_value()?; let signature = algorithm @@ -348,7 +389,7 @@ where #[allow(clippy::type_complexity)] pub fn verify( self, - algorithm: A, + algorithm: &A, ) -> Result, Fmt>, TokenVerifyingError> where A: crate::algorithms::VerifyAlgorithm, @@ -364,16 +405,16 @@ where } let signature = &self.state.signature; - let header = self - .state - .header - .verify::(algorithm.key()) - .map_err(TokenVerifyingError::Verify)?; - let headers = Base64JSON(&header).serialized_value()?; - let payload = self.payload.serialized_value()?; let signature = algorithm - .verify(&headers, &payload, signature.as_ref()) + .verify( + &self.state.header.state.raw, + &self.state.payload, + signature.as_ref(), + ) .map_err(TokenVerifyingError::Verify)?; + + let header = self.state.header.into_signed_header::(algorithm.key()); + Ok(Token { payload: self.payload, state: Verified { header, signature }, @@ -382,20 +423,41 @@ where } } +impl FromStr for Token, Fmt> +where + P: DeserializeOwned, + H: DeserializeOwned, + Fmt: TokenFormat, +{ + type Err = TokenParseError; + + fn from_str(s: &str) -> Result { + Fmt::parse(Bytes::from(s.to_owned())) + } +} + impl Token, Fmt> where Fmt: TokenFormat, Alg: SigningAlgorithm, + Alg::Key: Clone, + H: Serialize, + P: Serialize, { /// Transition the token back into an unverified state. /// /// This method consumes the token and returns a new one, which still includes the signature /// but which is no longer considered verified. pub fn unverify(self) -> Token, Fmt> { + let payload = self + .payload + .serialized_bytes() + .expect("valid payload bytes"); Token { payload: self.payload, state: Unverified { - header: self.state.header.render(), + payload, + header: self.state.header.into_rendered_header(), signature: Base64Data(self.state.signature.as_ref().to_owned().into()), }, fmt: self.fmt, @@ -403,12 +465,67 @@ where } } +impl Token, Fmt> +where + Fmt: TokenFormat, + Alg: SigningAlgorithm, +{ + pub fn payload(&self) -> Option<&P> { + match &self.payload { + Payload::Json(data) => Some(data.as_ref()), + Payload::Empty => None, + } + } +} + +impl Token, Fmt> +where + Fmt: TokenFormat, + Alg: VerifyAlgorithm, + Alg::Key: Clone, + H: Serialize, + P: Serialize, +{ + /// Transition the token back into an unverified state. + /// + /// This method consumes the token and returns a new one, which still includes the signature + /// but which is no longer considered verified. + pub fn unverify(self) -> Token, Fmt> { + let payload = self + .payload + .serialized_bytes() + .expect("valid payload bytes"); + Token { + payload: self.payload, + state: Unverified { + payload, + header: self.state.header.into_rendered_header(), + signature: Base64Data(self.state.signature.as_ref().to_owned().into()), + }, + fmt: self.fmt, + } + } +} + +impl Token, Fmt> +where + Fmt: TokenFormat, + Alg: VerifyAlgorithm, +{ + pub fn payload(&self) -> Option<&P> { + match &self.payload { + Payload::Json(data) => Some(data.as_ref()), + Payload::Empty => None, + } + } +} + #[cfg(feature = "fmt")] impl fmt::JWTFormat for Token where S: HasSignature, ::Header: Serialize, - ::HeaderState: Serialize + HeaderState, + ::HeaderState: HeaderState, P: Serialize, Fmt: TokenFormat, { @@ -493,7 +610,7 @@ pub enum TokenSigningError { } #[cfg(all(test, feature = "rsa"))] -mod test { +mod test_rsa { use super::*; use crate::claims::Claims; @@ -502,6 +619,8 @@ mod test { use serde_json::json; use sha2::Sha256; + use signature::Keypair; + use crate::key::jwk_reader::rsa; fn strip_whitespace(s: &str) -> String { @@ -607,5 +726,100 @@ mod test { ) ) } + + let algorithm = algorithm.verifying_key(); + + signed.unverify().verify(&algorithm).unwrap(); + } +} + +#[cfg(all(test, feature = "ecdsa", feature = "p256"))] +mod test_ecdsa { + use super::*; + + use base64ct::Encoding; + use elliptic_curve::{FieldBytes, SecretKey}; + use serde_json::json; + use zeroize::Zeroize; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + fn ecdsa(jwk: &serde_json::Value) -> SecretKey { + let d_b64 = strip_whitespace(jwk["d"].as_str().unwrap()); + let mut d_bytes = FieldBytes::::default(); + base64ct::Base64UrlUnpadded::decode(&d_b64, &mut d_bytes).unwrap(); + + let key = SecretKey::from_slice(&d_bytes).unwrap(); + d_bytes.zeroize(); + key + } + + #[test] + fn rfc7515_example_a3() { + let pkey = &json!({ + "kty":"EC", + "crv":"P-256", + "x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU", + "y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0", + "d":"jpsQnnGQmL-YBIffH1136cspYG6-0iY7X1fCE9-E9LI" + }); + + let key = ecdsa(pkey); + + let token = Token::compact((), "This is a signed message"); + + let signed = token.sign(&key).unwrap(); + + let verifying_key: ecdsa::VerifyingKey<_> = key.public_key().into(); + + let verified = signed.unverify().verify(&verifying_key).unwrap(); + + assert_eq!(verified.payload(), Some(&"This is a signed message")); + } +} + +#[cfg(all(test, feature = "hmac"))] +mod test_hmac { + use crate::algorithms::hmac::{Hmac, HmacKey}; + + use super::*; + + use base64ct::Encoding; + use serde_json::json; + use sha2::Sha256; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + #[test] + fn rfc7515_example_a1() { + let pkey = &json!({ + "kty":"oct", + "k":"AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75 + aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow" + } + ); + + let key_data = strip_whitespace(pkey["k"].as_str().unwrap()); + + let decoded_len = 3 * key_data.len() / 4; + + let mut key = HmacKey::with_capacity(decoded_len); + key.resize(decoded_len, 0); + + base64ct::Base64UrlUnpadded::decode(&key_data, key.as_mut()).unwrap(); + + let algorithm: Hmac = Hmac::new(key); + + let token = Token::compact((), "This is an HMAC'd message"); + + let signed = token.sign(&algorithm).unwrap(); + + let verified = signed.unverify().verify(&algorithm).unwrap(); + + assert_eq!(verified.payload(), Some(&"This is an HMAC'd message")); } } diff --git a/src/token/state.rs b/src/token/state.rs index 9b1db01..3112ac6 100644 --- a/src/token/state.rs +++ b/src/token/state.rs @@ -1,7 +1,8 @@ -use serde::{Deserialize, Serialize}; +use bytes::Bytes; +use serde::Serialize; use crate::{ - algorithms::{Signature as SignatureBytes, SigningAlgorithm, VerifyAlgorithm}, + algorithms::{SignatureBytes, SigningAlgorithm, VerifyAlgorithm}, base64data::Base64Data, jose, }; @@ -183,12 +184,9 @@ where /// /// This state indicates that we have recieved the token from elsewhere, and /// many fields could be in inconsistnet states. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound( - serialize = "H: Serialize", - deserialize = "H: for<'deh> Deserialize<'deh>" -))] +#[derive(Debug, Clone)] pub struct Unverified { + pub(super) payload: Bytes, pub(super) header: jose::Header, pub(super) signature: Base64Data, }