From 8fc2acfb3fe0b0e1a0039a71640fdd520577f7b2 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 24 Nov 2023 17:10:24 +0000 Subject: [PATCH 1/3] Signature verification support Adds signature verification for tokens and associated header and algorithm methods. --- Cargo.toml | 1 + src/algorithms/ecdsa.rs | 68 +++++++++++++++++++++++++++++++++++++++-- src/algorithms/hmac.rs | 14 +++++---- src/algorithms/mod.rs | 19 ++++++++++-- src/algorithms/rsa.rs | 35 ++++++++++++++++++++- src/base64data.rs | 5 +++ src/jose/derive.rs | 7 +++++ src/jose/mod.rs | 37 +++++++++++++++++++--- src/jose/rendered.rs | 3 ++ src/token/mod.rs | 33 ++++++++++++++------ src/token/state.rs | 1 + 11 files changed, 199 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 33a0328..737c4af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ categories = [ [dependencies] base64ct = { version = "1.6.0", features = ["std"] } +bytes = { version = "1.5" } chrono = { version = "0.4.24", features = ["serde"] } elliptic-curve = { version = "0.13.4", features = [ "pkcs8", diff --git a/src/algorithms/ecdsa.rs b/src/algorithms/ecdsa.rs index 7d76f3c..75a1235 100644 --- a/src/algorithms/ecdsa.rs +++ b/src/algorithms/ecdsa.rs @@ -74,7 +74,9 @@ use ::ecdsa::{ PrimeCurve, SignatureSize, }; use base64ct::Encoding; +use bytes::BytesMut; use digest::generic_array::ArrayLength; +use ecdsa::{hazmat::VerifyPrimitive, Signature, VerifyingKey}; use elliptic_curve::{ ops::Invert, sec1::{Coordinates, FromEncodedPoint, ModulusSize, ToEncodedPoint}, @@ -92,6 +94,7 @@ pub use p384::NistP384; #[cfg(feature = "p521")] pub use p521::NistP521; +use signature::Verifier; impl crate::key::JWKeyType for PublicKey where @@ -109,7 +112,7 @@ where FieldBytesSize: ModulusSize, { fn parameters(&self) -> Vec<(String, serde_json::Value)> { - let mut params = Vec::with_capacity(2); + let mut params = Vec::with_capacity(3); params.push(( "crv".to_owned(), @@ -144,6 +147,31 @@ where self.public_key().parameters() } } + +impl crate::key::SerializeJWK for ecdsa::VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, +{ + fn parameters(&self) -> Vec<(String, serde_json::Value)> { + PublicKey::::from(self).parameters() + } +} + +impl crate::key::JWKeyType for ecdsa::VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, +{ + const KEY_TYPE: &'static str = "EC"; +} + impl crate::key::SerializeJWK for ecdsa::SigningKey where C: PrimeCurve + CurveArithmetic + JwkParameters, @@ -153,7 +181,7 @@ where FieldBytesSize: ModulusSize, { fn parameters(&self) -> Vec<(String, serde_json::Value)> { - todo!() + self.verifying_key().parameters() } } @@ -223,6 +251,42 @@ where } } +impl super::VerifyAlgorithm for VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters + ecdsa::hazmat::DigestPrimitive, + ::AffinePoint: VerifyPrimitive, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + MaxSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, + VerifyingKey: super::Algorithm>, + as Add>::Output: Add + ArrayLength, +{ + type Error = ecdsa::Error; + type Key = VerifyingKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + let signature = ecdsa::Signature::try_from(signature)?; + >>::verify(self, message.as_ref(), &signature)?; + Ok(signature.into()) + } + + fn key(&self) -> &Self::Key { + self + } +} + #[cfg(all(test, feature = "p256"))] mod test { diff --git a/src/algorithms/hmac.rs b/src/algorithms/hmac.rs index fd9b616..4315902 100644 --- a/src/algorithms/hmac.rs +++ b/src/algorithms/hmac.rs @@ -2,7 +2,7 @@ //! //! Based on the [hmac](https://crates.io/crates/hmac) crate. -use std::convert::Infallible; +use std::{convert::Infallible, marker::PhantomData}; use base64ct::Encoding; use digest::Mac; @@ -81,7 +81,7 @@ where D: digest::Digest + digest::core_api::BlockSizeUser, { key: HmacKey, - digest: hmac::SimpleHmac, + _digest: PhantomData, } impl Hmac @@ -93,8 +93,10 @@ where /// /// Signing keys are arbitrary bytes. pub fn new(key: HmacKey) -> Self { - let digest = SimpleHmac::new_from_slice(key.as_ref()).expect("Valid key"); - Self { key, digest } + Self { + key, + _digest: PhantomData, + } } /// Reference to the HMAC key. @@ -133,8 +135,8 @@ where fn sign(&self, header: &str, payload: &str) -> Result { // Create a new, one-shot digest for this signature. - let mut digest = self.digest.clone(); - digest.reset(); + let mut digest: SimpleHmac = + SimpleHmac::new_from_slice(self.key.as_ref()).expect("Valid key"); let message = format!("{}.{}", header, payload); digest.update(message.as_bytes()); Ok(digest.finalize().into_bytes()) diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 8b017e7..aa561ba 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -69,6 +69,21 @@ pub enum AlgorithmIdentifier { None, } +impl AlgorithmIdentifier { + /// Return whether this algorithm is available for signing. + pub fn available(&self) -> bool { + match self { + Self::None => true, + + Self::HS256 | Self::HS384 | Self::HS512 => cfg!(feature = "hmac"), + Self::RS256 | Self::RS384 | Self::RS512 => cfg!(feature = "rsa"), + Self::ES256 | Self::ES384 | Self::ES512 => cfg!(feature = "ecdsa"), + Self::EdDSA => cfg!(feature = "ed25519"), + Self::PS256 | Self::PS384 | Self::PS512 => cfg!(feature = "rsa"), + } + } +} + /// A trait to associate an alogritm identifier with an algorithm. /// /// Algorithm identifiers are used in JWS and JWE to indicate how a token is signed or encrypted. @@ -118,8 +133,8 @@ pub trait VerifyAlgorithm: Algorithm { /// and payload. fn verify( &self, - header: &str, - payload: &str, + header: &[u8], + payload: &[u8], signature: &[u8], ) -> Result; diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index 78e22d2..956f6b1 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -16,7 +16,8 @@ //! This algorithm is used to sign and verify JSON Web Tokens using the RSASSA-PSS. use base64ct::{Base64UrlUnpadded, Encoding}; -use rsa::pkcs1v15::SigningKey; +use bytes::BytesMut; +use rsa::pkcs1v15::{SigningKey, VerifyingKey}; use rsa::rand_core::OsRng; use rsa::signature::RandomizedSigner; use rsa::PublicKeyParts; @@ -69,6 +70,38 @@ where } } +impl super::VerifyAlgorithm for VerifyingKey +where + D: digest::Digest, + VerifyingKey: super::Algorithm, +{ + type Error = signature::Error; + + type Key = rsa::RsaPublicKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + use rsa::signature::Verifier; + let signature = rsa::pkcs1v15::Signature::try_from(signature).unwrap(); + + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + >::verify(self, message.as_ref(), &signature)?; + Ok(signature) + } + + fn key(&self) -> &Self::Key { + self.as_ref() + } +} + impl super::Algorithm for RsaPkcs1v15 { const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS256; type Signature = rsa::pkcs1v15::Signature; diff --git a/src/base64data.rs b/src/base64data.rs index d747358..76ed48f 100644 --- a/src/base64data.rs +++ b/src/base64data.rs @@ -120,6 +120,11 @@ where let inner = serde_json::to_vec(&self.0)?; Ok(base64ct::Base64UrlUnpadded::encode_string(&inner)) } + + pub(crate) fn serialized_bytes(&self) -> Result, serde_json::Error> { + let inner = serde_json::to_vec(&self.0)?; + Ok(inner.into_boxed_slice()) + } } impl AsRef for Base64JSON { diff --git a/src/jose/derive.rs b/src/jose/derive.rs index 12ff501..fa1c7bc 100644 --- a/src/jose/derive.rs +++ b/src/jose/derive.rs @@ -106,6 +106,13 @@ where KeyDerivation::Explicit(value) => DerivedKeyValue::Explicit(value), } } + + pub(super) fn explicit(value: Option) -> Self { + match value { + Some(value) => DerivedKeyValue::Explicit(value), + None => DerivedKeyValue::Omit, + } + } } impl ser::Serialize for DerivedKeyValue diff --git a/src/jose/mod.rs b/src/jose/mod.rs index 74b88d8..cf227ba 100644 --- a/src/jose/mod.rs +++ b/src/jose/mod.rs @@ -18,6 +18,7 @@ use url::Url; #[cfg(feature = "fmt")] use crate::base64data::Base64JSON; +use crate::token::TokenVerifyingError; use crate::{algorithms::AlgorithmIdentifier, key::SerializeJWK}; #[cfg(feature = "fmt")] @@ -108,7 +109,7 @@ const REGISTERED_HEADER_KEYS: [&str; 11] = [ #[non_exhaustive] pub struct Header { #[serde(flatten)] - state: State, + pub(crate) state: State, /// The set of registered header parameters from [JWS][] and [JWA][]. /// @@ -207,8 +208,17 @@ where /// Render a signed JWK header into its rendered /// form, where the derived fields have been built /// as necessary. - pub fn render(self) -> Header { + pub fn render(self) -> Header + where + H: Serialize, + SignedHeader: HeaderState, + { + let headers = Base64JSON(&self) + .serialized_bytes() + .expect("valid header value"); + let state = RenderedHeader { + raw: headers, algorithm: self.state.algorithm, key: self.state.key.build(), thumbprint: self.state.thumbprint.build(), @@ -229,11 +239,30 @@ impl Header { } #[allow(unused_variables)] - pub(crate) fn verify(self, key: &A::Key) -> Result>, A::Error> + pub(crate) fn verify( + self, + key: &A::Key, + ) -> Result>, TokenVerifyingError> where A: crate::algorithms::VerifyAlgorithm, { - todo!("verify"); + // This may need to only verify that the algorithm header matches the key algorithm. + if A::IDENTIFIER != self.state.algorithm { + return Err(TokenVerifyingError::Algorithm( + A::IDENTIFIER, + self.state.algorithm, + )); + } + Ok(Header { + state: SignedHeader { + algorithm: self.state.algorithm, + key: DerivedKeyValue::explicit(self.state.key), + thumbprint: DerivedKeyValue::explicit(self.state.thumbprint), + thumbprint_sha256: DerivedKeyValue::explicit(self.state.thumbprint_sha256), + }, + registered: self.registered, + custom: self.custom, + }) } } diff --git a/src/jose/rendered.rs b/src/jose/rendered.rs index 9f177d9..7bdc09b 100644 --- a/src/jose/rendered.rs +++ b/src/jose/rendered.rs @@ -15,6 +15,9 @@ use super::HeaderState; /// and not thd derivation, so the fields may be in inconsistent states. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RenderedHeader { + /// The raw bytes of the header, as it was signed. + pub(crate) raw: Box<[u8]>, + #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] #[serde(rename = "alg")] pub(super) algorithm: AlgorithmIdentifier, diff --git a/src/token/mod.rs b/src/token/mod.rs index 8b004ac..c582d8e 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -67,6 +67,13 @@ where Payload::Empty => Ok("".to_owned()), } } + + fn serialized_bytes(&self) -> Result, serde_json::Error> { + match self { + Payload::Json(data) => data.serialized_bytes(), + Payload::Empty => Ok(Box::new([])), + } + } } impl

From

for Payload

{ @@ -364,16 +371,16 @@ where } let signature = &self.state.signature; - let header = self - .state - .header - .verify::(algorithm.key()) - .map_err(TokenVerifyingError::Verify)?; - let headers = Base64JSON(&header).serialized_value()?; - let payload = self.payload.serialized_value()?; let signature = algorithm - .verify(&headers, &payload, signature.as_ref()) + .verify( + &self.state.header.state.raw, + &self.state.payload, + signature.as_ref(), + ) .map_err(TokenVerifyingError::Verify)?; + + let header = self.state.header.verify::(algorithm.key())?; + Ok(Token { payload: self.payload, state: Verified { header, signature }, @@ -386,15 +393,23 @@ impl Token, Fmt> where Fmt: TokenFormat, Alg: SigningAlgorithm, + Alg::Key: Clone, + H: Serialize, + P: Serialize, { /// Transition the token back into an unverified state. /// /// This method consumes the token and returns a new one, which still includes the signature /// but which is no longer considered verified. pub fn unverify(self) -> Token, Fmt> { + let payload = self + .payload + .serialized_bytes() + .expect("valid payload bytes"); Token { payload: self.payload, state: Unverified { + payload, header: self.state.header.render(), signature: Base64Data(self.state.signature.as_ref().to_owned().into()), }, @@ -408,7 +423,7 @@ impl fmt::JWTFormat for Token where S: HasSignature, ::Header: Serialize, - ::HeaderState: Serialize + HeaderState, + ::HeaderState: HeaderState, P: Serialize, Fmt: TokenFormat, { diff --git a/src/token/state.rs b/src/token/state.rs index 9b1db01..6bb1e9c 100644 --- a/src/token/state.rs +++ b/src/token/state.rs @@ -189,6 +189,7 @@ where deserialize = "H: for<'deh> Deserialize<'deh>" ))] pub struct Unverified { + pub(super) payload: Box<[u8]>, pub(super) header: jose::Header, pub(super) signature: Base64Data, } From fa19ff1b0c30e7cbf96e55d4873a14cdfb3f4f4b Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 24 Nov 2023 20:18:49 +0000 Subject: [PATCH 2/3] Add roundtrip support for tokens --- README.md | 52 ++++++++-- examples/rfc7515a2.rs | 52 ++++++++-- src/algorithms/ecdsa.rs | 18 ++++ src/algorithms/hmac.rs | 38 +++++++ src/algorithms/mod.rs | 16 +-- src/algorithms/rsa.rs | 12 ++- src/base64data.rs | 67 ++++++++++++- src/claims.rs | 43 ++++++-- src/jose/mod.rs | 102 ++++++++++--------- src/jose/rendered.rs | 6 +- src/token/formats.rs | 190 ++++++++++++++++++++++++++++++++++- src/token/mod.rs | 217 ++++++++++++++++++++++++++++++++++++++-- src/token/state.rs | 13 +-- 13 files changed, 710 insertions(+), 116 deletions(-) diff --git a/README.md b/README.md index 2ea01e1..0a005cf 100644 --- a/README.md +++ b/README.md @@ -72,17 +72,25 @@ stateDiagram-v2 To create a simple JWT, you'll need to provide an encryption key. This example uses the RSA encrption key defined in Appendix A.2 of [RFC 7515][JWS], don't re use it! -This example is reproduced from [`examples/rfc7515a2.rs`][/examples/rfc7515a2.rs] in the repository, -and can be run with `cargo run --example rfc7515a2`. +This example is reproduced from [`examples/rfc7515a2.rs`](./examples/rfc7515a2.rs) in the repository, +and can be run with `cargo run --example rfc7515-a2`. ```rust +use jaws::Compact; + // JAWS provides JWT format for printing JWTs in a style similar to the example above, // which is directly inspired by the way the ACME standard shows JWTs. use jaws::JWTFormat; // JAWS provides a single token type which is generic over the state of the token. +// The states are defined in the `state` module, and are used to track the +// signing and verification status. use jaws::Token; +// The unverified token state, used like `Token<.., Unverified<..>, ..>`. +// It is generic over the type of the custom header parameters. +use jaws::token::Unverified; + // JAWS provides type-safe support for JWT claims. use jaws::{Claims, RegisteredClaims}; @@ -94,10 +102,8 @@ use rsa::pkcs8::DecodePrivateKey; // function, so we get it here from the `sha2` crate in the RustCrypto suite. use sha2::Sha256; -// JAWS provides thin algorithm wrappers for algorithms which accept -// parameters beyond just the encryption or singing key. For example, the `RS256` -// algorithm accepts a hash function, but is otherwise identical to the other -// `RS*` hash functions. +// This is an alias for the RSA PKCS#1 v1.5 signing algorithm, which is +// implemented in the rsa crate as `rsa::pkcs1v15::SigningKey`. use jaws::algorithms::rsa::RsaPkcs1v15; // Using serde_json allows us to quickly construct a serializable payload, @@ -143,9 +149,21 @@ fn main() -> Result<(), Box> { // we provide the `typ` header, which is optional in the JWT spec. *token.header_mut().r#type() = Some("JWT".to_string()); + println!("Initial JWT"); + + // Initially the JWT has no defined signature: + println!("JWT:"); + println!("{}", token.formatted()); + // Sign the token with the algorithm, and print the result. let signed = token.sign(&alg).unwrap(); + println!("Signed JWT"); + + println!("JWT:"); + println!("{}", signed.formatted()); + println!("Token: {}", signed.rendered().unwrap()); + // We can't modify the token after signing it (that would change the signature) // but we can access fields and read from them: println!( @@ -154,9 +172,27 @@ fn main() -> Result<(), Box> { signed.header().algorithm(), ); - println!("Token: {}", signed.rendered().unwrap()); + // We can also verify tokens. + let token: Token, Unverified<()>, Compact> = + signed.rendered().unwrap().parse().unwrap(); + + println!("Parsed JWT"); + + // Unverified tokens can be printed for debugging, but there is deliberately + // no access to the payload, only to the header fields. println!("JWT:"); - println!("{}", signed.formatted()); + println!("{}", token.formatted()); + + // We can't access the claims until we verify the token. + let verified = token.verify(&alg).unwrap(); + + println!("Verified JWT"); + println!("JWT:"); + println!("{}", verified.formatted()); + println!( + "Payload: \n{}", + serde_json::to_string_pretty(&verified.payload()).unwrap() + ); Ok(()) } diff --git a/examples/rfc7515a2.rs b/examples/rfc7515a2.rs index ffc1b3c..9b78d88 100644 --- a/examples/rfc7515a2.rs +++ b/examples/rfc7515a2.rs @@ -1,12 +1,18 @@ use jaws::Compact; + // JAWS provides JWT format for printing JWTs in a style similar to the example above, // which is directly inspired by the way the ACME standard shows JWTs. use jaws::JWTFormat; -// JAWS provides strongly typed support for tokens, so we can only build an UnsignedToken, -// which we can sign to create a SignedToken or a plain Token. +// JAWS provides a single token type which is generic over the state of the token. +// The states are defined in the `state` module, and are used to track the +// signing and verification status. use jaws::Token; +// The unverified token state, used like `Token<.., Unverified<..>, ..>`. +// It is generic over the type of the custom header parameters. +use jaws::token::Unverified; + // JAWS provides type-safe support for JWT claims. use jaws::{Claims, RegisteredClaims}; @@ -18,10 +24,8 @@ use rsa::pkcs8::DecodePrivateKey; // function, so we get it here from the `sha2` crate in the RustCrypto suite. use sha2::Sha256; -// JAWS provides thin algorithm wrappers for algorithms which accept -// parameters beyond just the encryption or singing key. For example, the `RS256` -// algorithm accepts a hash function, but is otherwise identical to the other -// `RS*` hash functions. +// This is an alias for the RSA PKCS#1 v1.5 signing algorithm, which is +// implemented in the rsa crate as `rsa::pkcs1v15::SigningKey`. use jaws::algorithms::rsa::RsaPkcs1v15; // Using serde_json allows us to quickly construct a serializable payload, @@ -62,14 +66,26 @@ fn main() -> Result<(), Box> { // The unit type can be used here because it implements [serde::Serialize], // but a custom type could be passed if we wanted to have custom header // fields. - let mut token = Token::new((), claims, Compact); + let mut token = Token::compact((), claims); // We can modify the headers freely before signing the JWT. In this case, // we provide the `typ` header, which is optional in the JWT spec. *token.header_mut().r#type() = Some("JWT".to_string()); + println!("Initial JWT"); + + // Initially the JWT has no defined signature: + println!("JWT:"); + println!("{}", token.formatted()); + // Sign the token with the algorithm, and print the result. let signed = token.sign(&alg).unwrap(); + println!("Signed JWT"); + + println!("JWT:"); + println!("{}", signed.formatted()); + println!("Token: {}", signed.rendered().unwrap()); + // We can't modify the token after signing it (that would change the signature) // but we can access fields and read from them: println!( @@ -78,9 +94,27 @@ fn main() -> Result<(), Box> { signed.header().algorithm(), ); - println!("Token: {}", signed.rendered().unwrap()); + // We can also verify tokens. + let token: Token, Unverified<()>, Compact> = + signed.rendered().unwrap().parse().unwrap(); + + println!("Parsed JWT"); + + // Unverified tokens can be printed for debugging, but there is deliberately + // no access to the payload, only to the header fields. println!("JWT:"); - println!("{}", signed.formatted()); + println!("{}", token.formatted()); + + // We can't access the claims until we verify the token. + let verified = token.verify(&alg).unwrap(); + + println!("Verified JWT"); + println!("JWT:"); + println!("{}", verified.formatted()); + println!( + "Payload: \n{}", + serde_json::to_string_pretty(&verified.payload()).unwrap() + ); Ok(()) } diff --git a/src/algorithms/ecdsa.rs b/src/algorithms/ecdsa.rs index 75a1235..3328553 100644 --- a/src/algorithms/ecdsa.rs +++ b/src/algorithms/ecdsa.rs @@ -287,6 +287,24 @@ where } } +#[cfg(feature = "p256")] +impl super::Algorithm for VerifyingKey { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::ES256; + type Signature = ecdsa::SignatureBytes; +} + +#[cfg(feature = "p384")] +impl super::Algorithm for VerifyingKey { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::ES384; + type Signature = ecdsa::SignatureBytes; +} + +#[cfg(feature = "p521")] +impl super::Algorithm for VerifyingKey { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::ES512; + type Signature = ecdsa::SignatureBytes; +} + #[cfg(all(test, feature = "p256"))] mod test { diff --git a/src/algorithms/hmac.rs b/src/algorithms/hmac.rs index 4315902..51259d6 100644 --- a/src/algorithms/hmac.rs +++ b/src/algorithms/hmac.rs @@ -5,6 +5,7 @@ use std::{convert::Infallible, marker::PhantomData}; use base64ct::Encoding; +use bytes::BytesMut; use digest::Mac; use hmac::SimpleHmac; @@ -147,6 +148,43 @@ where } } +impl super::VerifyAlgorithm for Hmac +where + D: digest::Digest + + digest::Reset + + digest::core_api::BlockSizeUser + + digest::FixedOutput + + digest::core_api::CoreProxy + + Clone, + Hmac: super::Algorithm>>, +{ + type Error = digest::MacError; + type Key = HmacKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + // Create a new, one-shot digest for this signature. + let mut digest: SimpleHmac = + SimpleHmac::new_from_slice(self.key.as_ref()).expect("Valid key"); + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + digest.update(message.as_ref()); + digest.clone().verify(signature.into())?; + Ok(digest.finalize().into_bytes()) + } + + fn key(&self) -> &Self::Key { + &self.key + } +} + #[cfg(test)] mod test { use crate::algorithms::SigningAlgorithm; diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index aa561ba..85f5499 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -148,9 +148,9 @@ pub trait VerifyAlgorithm: Algorithm { /// on the heap. It is used to store the signature of a JWT before it is verified, /// or if a signature has a variable length. #[derive(Debug, Clone, PartialEq, Eq, Hash, zeroize::Zeroize, zeroize::ZeroizeOnDrop)] -pub struct Signature(Vec); +pub struct SignatureBytes(Vec); -impl Signature { +impl SignatureBytes { /// Add to this signature from a byte slice. pub fn extend_from_slice(&mut self, other: &[u8]) { self.0.extend_from_slice(other); @@ -158,24 +158,24 @@ impl Signature { /// Create a new signature with the given capacity. pub fn with_capacity(capacity: usize) -> Self { - Signature(Vec::with_capacity(capacity)) + SignatureBytes(Vec::with_capacity(capacity)) } } -impl AsRef<[u8]> for Signature { +impl AsRef<[u8]> for SignatureBytes { fn as_ref(&self) -> &[u8] { &self.0 } } -impl From<&[u8]> for Signature { +impl From<&[u8]> for SignatureBytes { fn from(bytes: &[u8]) -> Self { - Signature(bytes.to_vec()) + SignatureBytes(bytes.to_vec()) } } -impl From> for Signature { +impl From> for SignatureBytes { fn from(bytes: Vec) -> Self { - Signature(bytes) + SignatureBytes(bytes) } } diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index 956f6b1..13271b9 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -17,7 +17,7 @@ use base64ct::{Base64UrlUnpadded, Encoding}; use bytes::BytesMut; -use rsa::pkcs1v15::{SigningKey, VerifyingKey}; +use rsa::pkcs1v15::SigningKey; use rsa::rand_core::OsRng; use rsa::signature::RandomizedSigner; use rsa::PublicKeyParts; @@ -70,10 +70,10 @@ where } } -impl super::VerifyAlgorithm for VerifyingKey +impl super::VerifyAlgorithm for RsaPkcs1v15 where D: digest::Digest, - VerifyingKey: super::Algorithm, + RsaPkcs1v15: super::Algorithm + Clone, { type Error = signature::Error; @@ -85,7 +85,7 @@ where payload: &[u8], signature: &[u8], ) -> Result { - use rsa::signature::Verifier; + use rsa::signature::{Keypair, Verifier}; let signature = rsa::pkcs1v15::Signature::try_from(signature).unwrap(); let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); @@ -93,7 +93,9 @@ where message.extend_from_slice(b"."); message.extend_from_slice(payload); - >::verify(self, message.as_ref(), &signature)?; + let verify = self.verifying_key(); + + verify.verify(message.as_ref(), &signature)?; Ok(signature) } diff --git a/src/base64data.rs b/src/base64data.rs index 76ed48f..77dc6d9 100644 --- a/src/base64data.rs +++ b/src/base64data.rs @@ -7,11 +7,27 @@ use std::fmt::Write; use std::marker::PhantomData; use base64ct::Encoding; -use serde::{de, ser, Serialize}; +use bytes::Bytes; +use serde::{ + de::{self, DeserializeOwned}, + ser, Serialize, +}; #[cfg(feature = "fmt")] use super::fmt::{self, IndentWriter}; +#[derive(Debug, thiserror::Error)] +pub enum DecodeError { + #[error(transparent)] + Base64(#[from] base64ct::Error), + + #[error(transparent)] + Json(#[from] serde_json::Error), + + #[error("data is not valid: {0}")] + InvalidData(#[source] Box), +} + /// Wrapper type to indicate that the inner type should be serialized /// as bytes with a Base64 URL-safe encoding. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -26,6 +42,18 @@ where } } +impl Base64Data +where + T: TryFrom>, + T::Error: std::error::Error + Send + Sync + 'static, +{ + pub(crate) fn parse(value: &str) -> Result { + let data = base64ct::Base64UrlUnpadded::decode_vec(value)?; + let data = T::try_from(data).map_err(|err| DecodeError::InvalidData(err.into()))?; + Ok(Base64Data(data)) + } +} + impl From for Base64Data { fn from(value: T) -> Self { Base64Data(value) @@ -112,6 +140,16 @@ where #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Base64JSON(pub T); +impl Base64JSON { + pub fn new(value: T) -> Self { + Base64JSON(value) + } + + pub fn into_inner(self) -> T { + self.0 + } +} + impl Base64JSON where T: Serialize, @@ -121,9 +159,30 @@ where Ok(base64ct::Base64UrlUnpadded::encode_string(&inner)) } - pub(crate) fn serialized_bytes(&self) -> Result, serde_json::Error> { - let inner = serde_json::to_vec(&self.0)?; - Ok(inner.into_boxed_slice()) + pub(crate) fn serialized_bytes(&self) -> Result { + self.serialized_value().map(|value| Bytes::from(value)) + } +} + +pub(crate) struct ParsedBase64JSON { + pub(crate) data: T, + pub(crate) bytes: Bytes, +} + +impl Base64JSON +where + T: DeserializeOwned, +{ + pub(crate) fn parse(raw: &str) -> Result, DecodeError> + where + T: de::DeserializeOwned, + { + let data = base64ct::Base64UrlUnpadded::decode_vec(raw)?; + let value = serde_json::from_slice(&data)?; + Ok(ParsedBase64JSON { + data: value, + bytes: Bytes::from(raw.to_owned()), + }) } } diff --git a/src/claims.rs b/src/claims.rs index d7e771f..64e5bce 100644 --- a/src/claims.rs +++ b/src/claims.rs @@ -22,13 +22,13 @@ use crate::fmt; /// request. /// /// Fields which are `None` are left out of the regsitered header. -#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct RegisteredClaims { /// Claim issuer identifies the principal that issued the /// JWT. The processing of this claim is generally application specific. /// The "iss" value is a case-sensitive string containing a StringOrURI /// value. Use of this claim is OPTIONAL. - #[serde(rename = "iss", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "iss", skip_serializing_if = "Option::is_none")] pub issuer: Option, /// Claim subject identifies the principal that is the @@ -38,7 +38,7 @@ pub struct RegisteredClaims, /// The "aud" (audience) claim identifies the recipients that the JWT is @@ -51,7 +51,7 @@ pub struct RegisteredClaims, /// The "exp" (expiration time) claim identifies the expiration time on or @@ -62,6 +62,7 @@ pub struct RegisteredClaims>, /// The "jti" (JWT ID) claim provides a unique identifier for the JWT. The identifier value MUST be assigned in a manner that ensures that there is a negligible probability that the same value will be accidentally assigned to a different data object; if the application uses multiple issuers, collisions MUST be prevented among values produced by different issuers as well. The "jti" claim can be used to prevent the JWT from being replayed. The "jti" value is a case- sensitive string. Use of this claim is OPTIONAL. - #[serde(rename = "jti", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "jti", skip_serializing_if = "Option::is_none")] pub token_id: Option, } +impl Default for RegisteredClaims { + fn default() -> Self { + Self { + issuer: None, + subject: None, + audience: None, + expiration: None, + not_before: None, + issued_at: None, + token_id: None, + } + } +} + #[cfg(feature = "fmt")] impl fmt::JWTFormat for RegisteredClaims where @@ -117,10 +134,10 @@ where /// They consist of "registered" header values, specified in RFC 7519, /// and a set of custom claims, which can be any arbitrary key-value /// pairs seializable as JSON. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)] pub struct Claims { /// Registered claims, which are enumerated specifically. See [RegisteredClaims]. - #[serde(flatten)] + #[serde(default, flatten)] pub registered: RegisteredClaims, /// Custom claims, which are any arbitrary JSON objects. Custom claims must implement @@ -130,6 +147,18 @@ pub struct Claims { pub claims: C, } +impl Default for Claims +where + C: Default, +{ + fn default() -> Self { + Self { + registered: Default::default(), + claims: Default::default(), + } + } +} + impl Claims { /// Create a new set of claims. Claims can also be created by constructing the /// struct literal. diff --git a/src/jose/mod.rs b/src/jose/mod.rs index cf227ba..d6fccd6 100644 --- a/src/jose/mod.rs +++ b/src/jose/mod.rs @@ -18,7 +18,6 @@ use url::Url; #[cfg(feature = "fmt")] use crate::base64data::Base64JSON; -use crate::token::TokenVerifyingError; use crate::{algorithms::AlgorithmIdentifier, key::SerializeJWK}; #[cfg(feature = "fmt")] @@ -55,6 +54,9 @@ pub enum HeaderError { #[error("invalid header type: {0}")] InvalidType(String), + #[error("invalid custom headers: {0} JSON serialized form must be an object or null")] + InvalidCustomHeaders(&'static str), + #[error("unable to serialize header value: {0}")] Serde(#[from] serde_json::Error), } @@ -70,31 +72,31 @@ pub enum HeaderError { #[derive(Debug, Clone, Serialize, Default, PartialEq, Eq, Deserialize)] struct RegisteredHeaderFields { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/jwk_set_url.md"))] - #[serde(rename = "jku", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "jku", skip_serializing_if = "Option::is_none")] jwk_set_url: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/type.md"))] - #[serde(rename = "typ", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "typ", skip_serializing_if = "Option::is_none")] r#type: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/key_id.md"))] - #[serde(rename = "kid", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")] key_id: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/certificate_url.md"))] - #[serde(rename = "x5u", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "x5u", skip_serializing_if = "Option::is_none")] pub certificate_url: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/certificate_chain.md"))] - #[serde(rename = "x5c", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "x5c", skip_serializing_if = "Option::is_none")] certificate_chain: Option>, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/content_type.md"))] - #[serde(rename = "cty", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "cty", skip_serializing_if = "Option::is_none")] content_type: Option, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/critical.md"))] - #[serde(rename = "crit", skip_serializing_if = "Option::is_none")] + #[serde(default, rename = "crit", skip_serializing_if = "Option::is_none")] critical: Option>, } @@ -161,7 +163,7 @@ impl Header { } /// Construct the JOSE header from the builder and signing key. - pub(crate) fn sign(self, key: &A::Key) -> Header> + pub(crate) fn into_signed_header(self, key: &A::Key) -> Header> where A: crate::algorithms::SigningAlgorithm, A::Key: Clone, @@ -179,21 +181,6 @@ impl Header { custom: self.custom, } } - - /// Access the JWK setting for the header. - pub fn jwk(&mut self) -> &mut KeyDerivation { - &mut self.state.key - } - - /// Access the JWK thumbprint setting for the header. - pub fn thumbprint(&mut self) -> &mut KeyDerivation> { - &mut self.state.thumbprint - } - - /// Access the JWK thumbprint setting for the header. - pub fn thumbprint_sha256(&mut self) -> &mut KeyDerivation> { - &mut self.state.thumbprint_sha256 - } } impl Header> @@ -201,14 +188,18 @@ where Key: SerializeJWK, { /// JWK signing algorithm in use. - pub fn algorithm(&self) -> &AlgorithmIdentifier { + pub(crate) fn algorithm(&self) -> &AlgorithmIdentifier { &self.state.algorithm } /// Render a signed JWK header into its rendered /// form, where the derived fields have been built /// as necessary. - pub fn render(self) -> Header + /// + /// This will cause the header to "forget" what key and algorithm + /// are used in the signature, rendering those values literally into + /// the header. + pub(crate) fn into_rendered_header(self) -> Header where H: Serialize, SignedHeader: HeaderState, @@ -219,7 +210,7 @@ where let state = RenderedHeader { raw: headers, - algorithm: self.state.algorithm, + algorithm: *self.algorithm(), key: self.state.key.build(), thumbprint: self.state.thumbprint.build(), thumbprint_sha256: self.state.thumbprint_sha256.build(), @@ -234,35 +225,41 @@ where } impl Header { - pub fn algorithm(&self) -> &AlgorithmIdentifier { + /// JWK signing algorithm in use. + pub(crate) fn algorithm(&self) -> &AlgorithmIdentifier { &self.state.algorithm } + /// Convert a rendered header into a signed header, where the algorithms + /// must match. This is used when verifying a token, but cannot check + /// the validity of any other header fields. + /// + /// # Panics + /// + /// If the key algorithm does not match the header's algorithm. #[allow(unused_variables)] - pub(crate) fn verify( - self, - key: &A::Key, - ) -> Result>, TokenVerifyingError> + pub(crate) fn into_signed_header(self, key: &A::Key) -> Header> where A: crate::algorithms::VerifyAlgorithm, { - // This may need to only verify that the algorithm header matches the key algorithm. - if A::IDENTIFIER != self.state.algorithm { - return Err(TokenVerifyingError::Algorithm( + if *self.algorithm() != A::IDENTIFIER { + panic!( + "algorithm mismatch: expected header to have {:?}, got {:?}", A::IDENTIFIER, - self.state.algorithm, - )); + self.algorithm() + ); } - Ok(Header { + + Header { state: SignedHeader { - algorithm: self.state.algorithm, + algorithm: *self.algorithm(), key: DerivedKeyValue::explicit(self.state.key), thumbprint: DerivedKeyValue::explicit(self.state.thumbprint), thumbprint_sha256: DerivedKeyValue::explicit(self.state.thumbprint_sha256), }, registered: self.registered, custom: self.custom, - }) + } } } @@ -271,6 +268,17 @@ where H: Serialize, State: HeaderState, { + /// Construct a JOSE header value. + /// + /// JOSE headers are JSON objects, so this method will serialize the header + /// into a JSON object, with keys lexicographically ordered. + /// + /// # Panics + /// + /// If a registered header were to conflict with a header owned by the type's + /// [HeaderState] implementation, this method will panic. That should not happen. + /// + /// If a custom header is not a JSON object or Null, this method will panic. pub(crate) fn value(&self) -> Result { // Re-using the parameters map here is important, because it will // alphabetize our keys, resulting in a consistent key order in rendered @@ -288,7 +296,7 @@ where } } Value::Null => {} - _ => panic!("registered headers are objects"), + _ => unreachable!("registered headers are objects"), } match custom { @@ -303,7 +311,7 @@ where } } Value::Null => {} - _ => panic!("custom headers are objects"), + _ => return Err(HeaderError::InvalidCustomHeaders(std::any::type_name::())), }; let mut map = serde_json::Map::new(); @@ -405,7 +413,7 @@ where { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] pub fn algorithm(&self) -> &AlgorithmIdentifier { - &self.header.state.algorithm + self.header.algorithm() } #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/json_web_key.md"))] @@ -427,7 +435,7 @@ where impl<'h, H> HeaderAccess<'h, H, RenderedHeader> { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] pub fn algorithm(&self) -> &AlgorithmIdentifier { - &self.header.state.algorithm + self.header.algorithm() } #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/json_web_key.md"))] @@ -530,7 +538,7 @@ where { #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] pub fn algorithm(&self) -> &AlgorithmIdentifier { - &self.header.state.algorithm + self.header.algorithm() } #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/json_web_key.md"))] @@ -570,7 +578,3 @@ impl<'h, H> HeaderAccessMut<'h, H, RenderedHeader> { &mut self.header.state.thumbprint_sha256 } } - -/// Errors returned when verifying a header. -#[derive(Debug, thiserror::Error)] -pub enum VerifyHeaderError {} diff --git a/src/jose/rendered.rs b/src/jose/rendered.rs index 7bdc09b..63573dc 100644 --- a/src/jose/rendered.rs +++ b/src/jose/rendered.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use serde::{Deserialize, Serialize}; use sha1::Sha1; use sha2::Sha256; @@ -13,10 +14,11 @@ use super::HeaderState; /// /// This is different from [super::SignedHeader] in that it contains the actual data, /// and not thd derivation, so the fields may be in inconsistent states. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct RenderedHeader { /// The raw bytes of the header, as it was signed. - pub(crate) raw: Box<[u8]>, + #[serde(skip)] + pub(crate) raw: Bytes, #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] #[serde(rename = "alg")] diff --git a/src/token/formats.rs b/src/token/formats.rs index eca34d5..f501b48 100644 --- a/src/token/formats.rs +++ b/src/token/formats.rs @@ -1,11 +1,15 @@ use std::fmt::Write; +use bytes::Bytes; +use serde::de::DeserializeOwned; use serde::{ser, Deserialize, Serialize}; -use super::{HasSignature, MaybeSigned}; +use super::{HasSignature, MaybeSigned, Unverified}; use super::{Payload, Token}; -use crate::base64data::{Base64Data, Base64JSON}; -use crate::jose::HeaderState; +use crate::algorithms::SignatureBytes; +use crate::base64data::{Base64Data, Base64JSON, DecodeError}; +use crate::jose::{HeaderState, RenderedHeader}; +use crate::Header; /// A token format that serializes the token as a compact string. #[derive(Debug, Default, Clone, Serialize, Deserialize)] @@ -66,6 +70,36 @@ pub enum TokenFormattingError { IO(#[from] std::fmt::Error), } +/// Error returned when a token cannot be parsed from a string. +/// +/// This error can occur when deserializing the header or payload +#[derive(Debug, thiserror::Error)] +pub enum TokenParseError { + /// Unable to find the header in the raw data. + #[error("missing header")] + MissingHeader, + + /// Unable to find the payload in the raw data. + #[error("missing payload")] + MissingPayload, + + /// Unable to find the signature in the raw data. + #[error("missing signature")] + MissingSignature, + + #[error(transparent)] + Utf8(#[from] std::str::Utf8Error), + + #[error(transparent)] + Base64(#[from] DecodeError), + + #[error(transparent)] + JSON(#[from] serde_json::Error), + + #[error("unexpected JSON value for {0}: {1}")] + UnexpectedJSONValue(&'static str, serde_json::Value), +} + /// Trait for token formats, defining how they are serialized. pub trait TokenFormat { /// Render the token to the given writer. @@ -79,6 +113,13 @@ pub trait TokenFormat { S: HasSignature, ::Header: Serialize, ::HeaderState: HeaderState; + + /// Parse the token from a slice. + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized; } impl TokenFormat for Compact { @@ -99,6 +140,49 @@ impl TokenFormat for Compact { write!(writer, "{}.{}.{}", header, payload, signature)?; Ok(()) } + + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized, + { + let mut parts = data.splitn(3, |&b| b == b'.'); + let header = { + let b64_header = + std::str::from_utf8(parts.next().ok_or(TokenParseError::MissingHeader)?)?; + + let wrapped_header = Base64JSON::>::parse(b64_header)?; + let mut header = wrapped_header.data; + header.state.raw = wrapped_header.bytes; + header + }; + + let (payload, raw_payload) = { + let b64_payload: &str = + std::str::from_utf8(parts.next().ok_or(TokenParseError::MissingPayload)?)?; + let payload: Payload

= Payload::parse(b64_payload)?; + let raw_payload: Vec = b64_payload.as_bytes().into(); + (payload, raw_payload) + }; + + let signature = { + let signature = parts.next().ok_or(TokenParseError::MissingSignature)?; + let signature: Base64Data = + Base64Data::parse(std::str::from_utf8(signature)?)?; + signature.into() + }; + + Ok(Token { + payload, + state: Unverified { + header, + signature, + payload: raw_payload.into(), + }, + fmt: Compact, + }) + } } #[derive(Debug, Serialize)] @@ -111,7 +195,7 @@ struct FlatToken<'t, P, U> { impl TokenFormat for FlatUnprotected where - U: Serialize, + U: Serialize + DeserializeOwned, { fn render( &self, @@ -139,6 +223,34 @@ where Ok(()) } + + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized, + { + let value: serde_json::Value = serde_json::from_slice(&data)?; + let serde_json::Value::Object(mut object): serde_json::Value = value else { + return Err(TokenParseError::UnexpectedJSONValue("token", value)); + }; + + let Token { payload, state, .. } = parse_flat_common_values(&mut object)?; + + let unprotected = { + let unprotected = object + .remove("unprotected") + .ok_or(TokenParseError::MissingHeader)?; + let unprotected: U = serde_json::from_value(unprotected)?; + unprotected + }; + + Ok(Token { + payload, + state, + fmt: FlatUnprotected { unprotected }, + }) + } } impl Serialize for Token> @@ -146,7 +258,7 @@ where S: HasSignature, ::Header: Serialize, ::HeaderState: HeaderState, - U: Serialize, + U: Serialize + DeserializeOwned, P: Serialize, { fn serialize(&self, serializer: Ser) -> Result @@ -204,6 +316,74 @@ impl TokenFormat for Flat { Ok(()) } + + fn parse(data: Bytes) -> Result, Self>, TokenParseError> + where + P: DeserializeOwned, + H: DeserializeOwned, + Self: Sized, + { + let value: serde_json::Value = serde_json::from_slice(&data)?; + let serde_json::Value::Object(mut object): serde_json::Value = value else { + return Err(TokenParseError::UnexpectedJSONValue("token", value)); + }; + parse_flat_common_values(&mut object) + } +} + +fn parse_flat_common_values( + object: &mut serde_json::Map, +) -> Result, Flat>, TokenParseError> +where + P: DeserializeOwned, + H: DeserializeOwned, +{ + let header = { + let protected = object + .get("protected") + .ok_or(TokenParseError::MissingHeader)?; + let protected = + Base64JSON::>::parse(protected.as_str().ok_or_else( + || TokenParseError::UnexpectedJSONValue("header", protected.clone()), + )?)?; + let mut header = protected.data; + header.state.raw = protected.bytes; + header + }; + + let (payload, raw_payload) = { + let value_payload = object + .remove("payload") + .ok_or(TokenParseError::MissingPayload)?; + let b64_payload = value_payload.as_str().ok_or_else(|| { + TokenParseError::UnexpectedJSONValue("payload", value_payload.clone()) + })?; + + let payload: Payload

= Payload::parse(b64_payload)?; + let raw_payload: Vec = b64_payload.as_bytes().into(); + (payload, raw_payload) + }; + + let signature = { + let signature = object + .remove("signature") + .ok_or(TokenParseError::MissingSignature)?; + let signature: Base64Data = + Base64Data::parse(signature.as_str().ok_or_else(|| { + TokenParseError::UnexpectedJSONValue("signature", signature.clone()) + })?)?; + signature.into() + }; + + Ok(Token { + payload, + state: Unverified { + header, + signature, + payload: raw_payload.into(), + }, + fmt: Flat, + }) } impl Serialize for Token diff --git a/src/token/mod.rs b/src/token/mod.rs index c582d8e..f07b5e8 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -8,17 +8,22 @@ //! //! [RFC7519]: https://tools.ietf.org/html/rfc7519 -use std::fmt::Write; use std::marker::PhantomData; +use std::{fmt::Write, str::FromStr}; use base64ct::Encoding; -use serde::{de, ser, Deserialize, Serialize}; +use bytes::Bytes; +use serde::{ + de::{self, DeserializeOwned}, + ser, Deserialize, Serialize, +}; +use crate::algorithms::VerifyAlgorithm; #[cfg(feature = "fmt")] use crate::fmt; use crate::{ algorithms::{AlgorithmIdentifier, SigningAlgorithm}, - base64data::{Base64Data, Base64JSON}, + base64data::{Base64Data, Base64JSON, DecodeError}, jose::{HeaderAccess, HeaderAccessMut, HeaderState}, Header, }; @@ -26,6 +31,7 @@ use crate::{ mod formats; mod state; +use self::formats::TokenParseError; pub use self::formats::{Compact, Flat, FlatUnprotected, TokenFormat, TokenFormattingError}; pub use self::state::{HasSignature, MaybeSigned, Signed, Unsigned, Unverified, Verified}; @@ -68,14 +74,28 @@ where } } - fn serialized_bytes(&self) -> Result, serde_json::Error> { + fn serialized_bytes(&self) -> Result { match self { Payload::Json(data) => data.serialized_bytes(), - Payload::Empty => Ok(Box::new([])), + Payload::Empty => Ok(Bytes::new()), } } } +impl

Payload

+where + P: DeserializeOwned, +{ + fn parse(value: &str) -> Result { + if value.is_empty() { + return Ok(Payload::Empty); + } + + let parsed = Base64JSON::

::parse(value)?; + Ok(Payload::Json(parsed.data.into())) + } +} + impl

From

for Payload

{ fn from(value: P) -> Self { Payload::Json(value.into()) @@ -156,10 +176,12 @@ where /// let token = Token::compact((), ()); /// ``` /// -/// This token will have no payload, and no custom headers, but it is still usable: +/// This token will have no payload, and no custom headers. #[cfg_attr( feature = "fmt", doc = r#" +To view a debug representation of the token, use the [`fmt::JWTFormat`] trait: + ``` # use jaws::token::Token; # let token = Token::compact((), ()); @@ -305,6 +327,18 @@ where } } +impl Token, Fmt> +where + Fmt: TokenFormat, +{ + pub fn payload(&self) -> Option<&P> { + match &self.payload { + Payload::Json(data) => Some(data.as_ref()), + Payload::Empty => None, + } + } +} + impl Token, Fmt> where H: Serialize, @@ -326,7 +360,7 @@ where A::Key: Clone, // A::Signature: Serialize, { - let header = self.state.header.sign::(algorithm.key()); + let header = self.state.header.into_signed_header::(algorithm.key()); let headers = Base64JSON(&header).serialized_value()?; let payload = self.payload.serialized_value()?; let signature = algorithm @@ -355,7 +389,7 @@ where #[allow(clippy::type_complexity)] pub fn verify( self, - algorithm: A, + algorithm: &A, ) -> Result, Fmt>, TokenVerifyingError> where A: crate::algorithms::VerifyAlgorithm, @@ -379,7 +413,7 @@ where ) .map_err(TokenVerifyingError::Verify)?; - let header = self.state.header.verify::(algorithm.key())?; + let header = self.state.header.into_signed_header::(algorithm.key()); Ok(Token { payload: self.payload, @@ -389,6 +423,19 @@ where } } +impl FromStr for Token, Fmt> +where + P: DeserializeOwned, + H: DeserializeOwned, + Fmt: TokenFormat, +{ + type Err = TokenParseError; + + fn from_str(s: &str) -> Result { + Fmt::parse(Bytes::from(s.to_owned())) + } +} + impl Token, Fmt> where Fmt: TokenFormat, @@ -410,7 +457,7 @@ where payload: self.payload, state: Unverified { payload, - header: self.state.header.render(), + header: self.state.header.into_rendered_header(), signature: Base64Data(self.state.signature.as_ref().to_owned().into()), }, fmt: self.fmt, @@ -418,6 +465,61 @@ where } } +impl Token, Fmt> +where + Fmt: TokenFormat, + Alg: SigningAlgorithm, +{ + pub fn payload(&self) -> Option<&P> { + match &self.payload { + Payload::Json(data) => Some(data.as_ref()), + Payload::Empty => None, + } + } +} + +impl Token, Fmt> +where + Fmt: TokenFormat, + Alg: VerifyAlgorithm, + Alg::Key: Clone, + H: Serialize, + P: Serialize, +{ + /// Transition the token back into an unverified state. + /// + /// This method consumes the token and returns a new one, which still includes the signature + /// but which is no longer considered verified. + pub fn unverify(self) -> Token, Fmt> { + let payload = self + .payload + .serialized_bytes() + .expect("valid payload bytes"); + Token { + payload: self.payload, + state: Unverified { + payload, + header: self.state.header.into_rendered_header(), + signature: Base64Data(self.state.signature.as_ref().to_owned().into()), + }, + fmt: self.fmt, + } + } +} + +impl Token, Fmt> +where + Fmt: TokenFormat, + Alg: VerifyAlgorithm, +{ + pub fn payload(&self) -> Option<&P> { + match &self.payload { + Payload::Json(data) => Some(data.as_ref()), + Payload::Empty => None, + } + } +} + #[cfg(feature = "fmt")] impl fmt::JWTFormat for Token where @@ -508,7 +610,7 @@ pub enum TokenSigningError { } #[cfg(all(test, feature = "rsa"))] -mod test { +mod test_rsa { use super::*; use crate::claims::Claims; @@ -622,5 +724,98 @@ mod test { ) ) } + + signed.unverify().verify(&algorithm).unwrap(); + } +} + +#[cfg(all(test, feature = "ecdsa", feature = "p256"))] +mod test_ecdsa { + use super::*; + + use base64ct::Encoding; + use elliptic_curve::{FieldBytes, SecretKey}; + use serde_json::json; + use zeroize::Zeroize; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + fn ecdsa(jwk: &serde_json::Value) -> SecretKey { + let d_b64 = strip_whitespace(jwk["d"].as_str().unwrap()); + let mut d_bytes = FieldBytes::::default(); + base64ct::Base64UrlUnpadded::decode(&d_b64, &mut d_bytes).unwrap(); + + let key = SecretKey::from_slice(&d_bytes).unwrap(); + d_bytes.zeroize(); + key + } + + #[test] + fn rfc7515_example_a3() { + let pkey = &json!({ + "kty":"EC", + "crv":"P-256", + "x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU", + "y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0", + "d":"jpsQnnGQmL-YBIffH1136cspYG6-0iY7X1fCE9-E9LI" + }); + + let key = ecdsa(pkey); + + let token = Token::compact((), "This is a signed message"); + + let signed = token.sign(&key).unwrap(); + + let verifying_key: ecdsa::VerifyingKey<_> = key.public_key().into(); + + let verified = signed.unverify().verify(&verifying_key).unwrap(); + + assert_eq!(verified.payload(), Some(&"This is a signed message")); + } +} + +#[cfg(all(test, feature = "hmac"))] +mod test_hmac { + use crate::algorithms::hmac::{Hmac, HmacKey}; + + use super::*; + + use base64ct::Encoding; + use serde_json::json; + use sha2::Sha256; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + #[test] + fn rfc7515_example_a1() { + let pkey = &json!({ + "kty":"oct", + "k":"AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75 + aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow" + } + ); + + let key_data = strip_whitespace(pkey["k"].as_str().unwrap()); + + let decoded_len = 3 * key_data.len() / 4; + + let mut key = HmacKey::with_capacity(decoded_len); + key.resize(decoded_len, 0); + + base64ct::Base64UrlUnpadded::decode(&key_data, key.as_mut()).unwrap(); + + let algorithm: Hmac = Hmac::new(key); + + let token = Token::compact((), "This is an HMAC'd message"); + + let signed = token.sign(&algorithm).unwrap(); + + let verified = signed.unverify().verify(&algorithm).unwrap(); + + assert_eq!(verified.payload(), Some(&"This is an HMAC'd message")); } } diff --git a/src/token/state.rs b/src/token/state.rs index 6bb1e9c..3112ac6 100644 --- a/src/token/state.rs +++ b/src/token/state.rs @@ -1,7 +1,8 @@ -use serde::{Deserialize, Serialize}; +use bytes::Bytes; +use serde::Serialize; use crate::{ - algorithms::{Signature as SignatureBytes, SigningAlgorithm, VerifyAlgorithm}, + algorithms::{SignatureBytes, SigningAlgorithm, VerifyAlgorithm}, base64data::Base64Data, jose, }; @@ -183,13 +184,9 @@ where /// /// This state indicates that we have recieved the token from elsewhere, and /// many fields could be in inconsistnet states. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound( - serialize = "H: Serialize", - deserialize = "H: for<'deh> Deserialize<'deh>" -))] +#[derive(Debug, Clone)] pub struct Unverified { - pub(super) payload: Box<[u8]>, + pub(super) payload: Bytes, pub(super) header: jose::Header, pub(super) signature: Base64Data, } From 933192893ad9b996d987eba5c3da4616579b3fb1 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Fri, 24 Nov 2023 20:46:32 +0000 Subject: [PATCH 3/3] Fix: require verify keys for RSA signatures MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Don’t require the full private key to be present, and have the example check that adding the key as a JWK still round trips. --- README.md | 40 ++++++++++++++++++++++++++++++++++++++++ examples/rfc7515a2.rs | 39 +++++++++++++++++++++++++++++++++++++++ src/algorithms/rsa.rs | 29 +++++++++++++++++++++-------- src/base64data.rs | 2 +- src/jose/mod.rs | 1 - src/token/formats.rs | 6 +++--- src/token/mod.rs | 4 ++++ 7 files changed, 108 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0a005cf..70012a4 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ and can be run with `cargo run --example rfc7515-a2`. ```rust use jaws::Compact; +use std::ops::Deref; // JAWS provides JWT format for printing JWTs in a style similar to the example above, // which is directly inspired by the way the ACME standard shows JWTs. @@ -87,6 +88,7 @@ use jaws::JWTFormat; // signing and verification status. use jaws::Token; +use jaws::algorithms::rsa::RsaPkcs1v15Verify; // The unverified token state, used like `Token<.., Unverified<..>, ..>`. // It is generic over the type of the custom header parameters. use jaws::token::Unverified; @@ -149,6 +151,10 @@ fn main() -> Result<(), Box> { // we provide the `typ` header, which is optional in the JWT spec. *token.header_mut().r#type() = Some("JWT".to_string()); + // We can also ask that some fields be derived from the signing key, for example, + // this will derive the JWK field in the header from the signing key. + token.header_mut().key().derived(); + println!("Initial JWT"); // Initially the JWT has no defined signature: @@ -183,6 +189,15 @@ fn main() -> Result<(), Box> { println!("JWT:"); println!("{}", token.formatted()); + // We can use the JWK to verify that the token is signed with the correct key. + let hdr = token.header(); + let jwk = hdr.key().unwrap(); + let key = rsa_jwk_reader::rsa_pub(&serde_json::to_value(jwk).unwrap()); + + assert_eq!(&key, alg.as_ref().deref()); + + let alg: RsaPkcs1v15Verify = RsaPkcs1v15Verify::new_with_prefix(key); + // We can't access the claims until we verify the token. let verified = token.verify(&alg).unwrap(); @@ -196,6 +211,31 @@ fn main() -> Result<(), Box> { Ok(()) } + +mod rsa_jwk_reader { + use base64ct::Encoding; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + fn to_biguint(v: &serde_json::Value) -> Option { + let val = strip_whitespace(v.as_str()?); + Some(rsa::BigUint::from_bytes_be( + base64ct::Base64UrlUnpadded::decode_vec(&val) + .ok()? + .as_slice(), + )) + } + + pub(crate) fn rsa_pub(key: &serde_json::Value) -> rsa::RsaPublicKey { + let n = to_biguint(&key["n"]).expect("decode n"); + let e = to_biguint(&key["e"]).expect("decode e"); + + rsa::RsaPublicKey::new(n, e).expect("valid key parameters") + } +} + ``` ## Philosophy diff --git a/examples/rfc7515a2.rs b/examples/rfc7515a2.rs index 9b78d88..716e28d 100644 --- a/examples/rfc7515a2.rs +++ b/examples/rfc7515a2.rs @@ -1,4 +1,5 @@ use jaws::Compact; +use std::ops::Deref; // JAWS provides JWT format for printing JWTs in a style similar to the example above, // which is directly inspired by the way the ACME standard shows JWTs. @@ -9,6 +10,7 @@ use jaws::JWTFormat; // signing and verification status. use jaws::Token; +use jaws::algorithms::rsa::RsaPkcs1v15Verify; // The unverified token state, used like `Token<.., Unverified<..>, ..>`. // It is generic over the type of the custom header parameters. use jaws::token::Unverified; @@ -71,6 +73,10 @@ fn main() -> Result<(), Box> { // we provide the `typ` header, which is optional in the JWT spec. *token.header_mut().r#type() = Some("JWT".to_string()); + // We can also ask that some fields be derived from the signing key, for example, + // this will derive the JWK field in the header from the signing key. + token.header_mut().key().derived(); + println!("Initial JWT"); // Initially the JWT has no defined signature: @@ -105,6 +111,15 @@ fn main() -> Result<(), Box> { println!("JWT:"); println!("{}", token.formatted()); + // We can use the JWK to verify that the token is signed with the correct key. + let hdr = token.header(); + let jwk = hdr.key().unwrap(); + let key = rsa_jwk_reader::rsa_pub(&serde_json::to_value(jwk).unwrap()); + + assert_eq!(&key, alg.as_ref().deref()); + + let alg: RsaPkcs1v15Verify = RsaPkcs1v15Verify::new_with_prefix(key); + // We can't access the claims until we verify the token. let verified = token.verify(&alg).unwrap(); @@ -118,3 +133,27 @@ fn main() -> Result<(), Box> { Ok(()) } + +mod rsa_jwk_reader { + use base64ct::Encoding; + + fn strip_whitespace(s: &str) -> String { + s.chars().filter(|c| !c.is_whitespace()).collect() + } + + fn to_biguint(v: &serde_json::Value) -> Option { + let val = strip_whitespace(v.as_str()?); + Some(rsa::BigUint::from_bytes_be( + base64ct::Base64UrlUnpadded::decode_vec(&val) + .ok()? + .as_slice(), + )) + } + + pub(crate) fn rsa_pub(key: &serde_json::Value) -> rsa::RsaPublicKey { + let n = to_biguint(&key["n"]).expect("decode n"); + let e = to_biguint(&key["e"]).expect("decode e"); + + rsa::RsaPublicKey::new(n, e).expect("valid key parameters") + } +} diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index 13271b9..55d1ed9 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -17,7 +17,6 @@ use base64ct::{Base64UrlUnpadded, Encoding}; use bytes::BytesMut; -use rsa::pkcs1v15::SigningKey; use rsa::rand_core::OsRng; use rsa::signature::RandomizedSigner; use rsa::PublicKeyParts; @@ -50,7 +49,8 @@ impl crate::key::SerializeJWK for rsa::RsaPrivateKey { } /// Alogrithm wrapper for the Digital Signature with RSASSA-PKCS1-v1_5 algorithm. -pub type RsaPkcs1v15 = SigningKey; +pub type RsaPkcs1v15 = rsa::pkcs1v15::SigningKey; +pub type RsaPkcs1v15Verify = rsa::pkcs1v15::VerifyingKey; impl super::SigningAlgorithm for RsaPkcs1v15 where @@ -70,10 +70,10 @@ where } } -impl super::VerifyAlgorithm for RsaPkcs1v15 +impl super::VerifyAlgorithm for RsaPkcs1v15Verify where D: digest::Digest, - RsaPkcs1v15: super::Algorithm + Clone, + RsaPkcs1v15Verify: super::Algorithm + Clone, { type Error = signature::Error; @@ -85,7 +85,7 @@ where payload: &[u8], signature: &[u8], ) -> Result { - use rsa::signature::{Keypair, Verifier}; + use rsa::signature::Verifier; let signature = rsa::pkcs1v15::Signature::try_from(signature).unwrap(); let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); @@ -93,9 +93,7 @@ where message.extend_from_slice(b"."); message.extend_from_slice(payload); - let verify = self.verifying_key(); - - verify.verify(message.as_ref(), &signature)?; + >::verify(self, message.as_ref(), &signature)?; Ok(signature) } @@ -119,6 +117,21 @@ impl super::Algorithm for RsaPkcs1v15 { type Signature = rsa::pkcs1v15::Signature; } +impl super::Algorithm for RsaPkcs1v15Verify { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS256; + type Signature = rsa::pkcs1v15::Signature; +} + +impl super::Algorithm for RsaPkcs1v15Verify { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS384; + type Signature = rsa::pkcs1v15::Signature; +} + +impl super::Algorithm for RsaPkcs1v15Verify { + const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS512; + type Signature = rsa::pkcs1v15::Signature; +} + /// Algorithm wrapper for RSA-PSS signatures, using [rsa::pss::BlindedSigningKey]. pub type RsaPSSKey = rsa::pss::BlindedSigningKey; diff --git a/src/base64data.rs b/src/base64data.rs index 77dc6d9..4210ab3 100644 --- a/src/base64data.rs +++ b/src/base64data.rs @@ -160,7 +160,7 @@ where } pub(crate) fn serialized_bytes(&self) -> Result { - self.serialized_value().map(|value| Bytes::from(value)) + self.serialized_value().map(Bytes::from) } } diff --git a/src/jose/mod.rs b/src/jose/mod.rs index d6fccd6..78ccb3a 100644 --- a/src/jose/mod.rs +++ b/src/jose/mod.rs @@ -16,7 +16,6 @@ use sha1::Sha1; use sha2::Sha256; use url::Url; -#[cfg(feature = "fmt")] use crate::base64data::Base64JSON; use crate::{algorithms::AlgorithmIdentifier, key::SerializeJWK}; diff --git a/src/token/formats.rs b/src/token/formats.rs index f501b48..7d3f141 100644 --- a/src/token/formats.rs +++ b/src/token/formats.rs @@ -94,7 +94,7 @@ pub enum TokenParseError { Base64(#[from] DecodeError), #[error(transparent)] - JSON(#[from] serde_json::Error), + Json(#[from] serde_json::Error), #[error("unexpected JSON value for {0}: {1}")] UnexpectedJSONValue(&'static str, serde_json::Value), @@ -170,7 +170,7 @@ impl TokenFormat for Compact { let signature = parts.next().ok_or(TokenParseError::MissingSignature)?; let signature: Base64Data = Base64Data::parse(std::str::from_utf8(signature)?)?; - signature.into() + signature }; Ok(Token { @@ -372,7 +372,7 @@ where Base64Data::parse(signature.as_str().ok_or_else(|| { TokenParseError::UnexpectedJSONValue("signature", signature.clone()) })?)?; - signature.into() + signature }; Ok(Token { diff --git a/src/token/mod.rs b/src/token/mod.rs index f07b5e8..b3c993c 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -619,6 +619,8 @@ mod test_rsa { use serde_json::json; use sha2::Sha256; + use signature::Keypair; + use crate::key::jwk_reader::rsa; fn strip_whitespace(s: &str) -> String { @@ -725,6 +727,8 @@ mod test_rsa { ) } + let algorithm = algorithm.verifying_key(); + signed.unverify().verify(&algorithm).unwrap(); } }