diff --git a/Cargo.toml b/Cargo.toml index 33a0328..737c4af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ categories = [ [dependencies] base64ct = { version = "1.6.0", features = ["std"] } +bytes = { version = "1.5" } chrono = { version = "0.4.24", features = ["serde"] } elliptic-curve = { version = "0.13.4", features = [ "pkcs8", diff --git a/src/algorithms/ecdsa.rs b/src/algorithms/ecdsa.rs index 7d76f3c..75a1235 100644 --- a/src/algorithms/ecdsa.rs +++ b/src/algorithms/ecdsa.rs @@ -74,7 +74,9 @@ use ::ecdsa::{ PrimeCurve, SignatureSize, }; use base64ct::Encoding; +use bytes::BytesMut; use digest::generic_array::ArrayLength; +use ecdsa::{hazmat::VerifyPrimitive, Signature, VerifyingKey}; use elliptic_curve::{ ops::Invert, sec1::{Coordinates, FromEncodedPoint, ModulusSize, ToEncodedPoint}, @@ -92,6 +94,7 @@ pub use p384::NistP384; #[cfg(feature = "p521")] pub use p521::NistP521; +use signature::Verifier; impl crate::key::JWKeyType for PublicKey where @@ -109,7 +112,7 @@ where FieldBytesSize: ModulusSize, { fn parameters(&self) -> Vec<(String, serde_json::Value)> { - let mut params = Vec::with_capacity(2); + let mut params = Vec::with_capacity(3); params.push(( "crv".to_owned(), @@ -144,6 +147,31 @@ where self.public_key().parameters() } } + +impl crate::key::SerializeJWK for ecdsa::VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, +{ + fn parameters(&self) -> Vec<(String, serde_json::Value)> { + PublicKey::::from(self).parameters() + } +} + +impl crate::key::JWKeyType for ecdsa::VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, +{ + const KEY_TYPE: &'static str = "EC"; +} + impl crate::key::SerializeJWK for ecdsa::SigningKey where C: PrimeCurve + CurveArithmetic + JwkParameters, @@ -153,7 +181,7 @@ where FieldBytesSize: ModulusSize, { fn parameters(&self) -> Vec<(String, serde_json::Value)> { - todo!() + self.verifying_key().parameters() } } @@ -223,6 +251,42 @@ where } } +impl super::VerifyAlgorithm for VerifyingKey +where + C: PrimeCurve + CurveArithmetic + JwkParameters + ecdsa::hazmat::DigestPrimitive, + ::AffinePoint: VerifyPrimitive, + Scalar: Invert>> + SignPrimitive, + SignatureSize: ArrayLength, + MaxSize: ArrayLength, + AffinePoint: FromEncodedPoint + ToEncodedPoint, + FieldBytesSize: ModulusSize, + VerifyingKey: super::Algorithm>, + as Add>::Output: Add + ArrayLength, +{ + type Error = ecdsa::Error; + type Key = VerifyingKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + let signature = ecdsa::Signature::try_from(signature)?; + >>::verify(self, message.as_ref(), &signature)?; + Ok(signature.into()) + } + + fn key(&self) -> &Self::Key { + self + } +} + #[cfg(all(test, feature = "p256"))] mod test { diff --git a/src/algorithms/hmac.rs b/src/algorithms/hmac.rs index fd9b616..4315902 100644 --- a/src/algorithms/hmac.rs +++ b/src/algorithms/hmac.rs @@ -2,7 +2,7 @@ //! //! Based on the [hmac](https://crates.io/crates/hmac) crate. -use std::convert::Infallible; +use std::{convert::Infallible, marker::PhantomData}; use base64ct::Encoding; use digest::Mac; @@ -81,7 +81,7 @@ where D: digest::Digest + digest::core_api::BlockSizeUser, { key: HmacKey, - digest: hmac::SimpleHmac, + _digest: PhantomData, } impl Hmac @@ -93,8 +93,10 @@ where /// /// Signing keys are arbitrary bytes. pub fn new(key: HmacKey) -> Self { - let digest = SimpleHmac::new_from_slice(key.as_ref()).expect("Valid key"); - Self { key, digest } + Self { + key, + _digest: PhantomData, + } } /// Reference to the HMAC key. @@ -133,8 +135,8 @@ where fn sign(&self, header: &str, payload: &str) -> Result { // Create a new, one-shot digest for this signature. - let mut digest = self.digest.clone(); - digest.reset(); + let mut digest: SimpleHmac = + SimpleHmac::new_from_slice(self.key.as_ref()).expect("Valid key"); let message = format!("{}.{}", header, payload); digest.update(message.as_bytes()); Ok(digest.finalize().into_bytes()) diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 8b017e7..aa561ba 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -69,6 +69,21 @@ pub enum AlgorithmIdentifier { None, } +impl AlgorithmIdentifier { + /// Return whether this algorithm is available for signing. + pub fn available(&self) -> bool { + match self { + Self::None => true, + + Self::HS256 | Self::HS384 | Self::HS512 => cfg!(feature = "hmac"), + Self::RS256 | Self::RS384 | Self::RS512 => cfg!(feature = "rsa"), + Self::ES256 | Self::ES384 | Self::ES512 => cfg!(feature = "ecdsa"), + Self::EdDSA => cfg!(feature = "ed25519"), + Self::PS256 | Self::PS384 | Self::PS512 => cfg!(feature = "rsa"), + } + } +} + /// A trait to associate an alogritm identifier with an algorithm. /// /// Algorithm identifiers are used in JWS and JWE to indicate how a token is signed or encrypted. @@ -118,8 +133,8 @@ pub trait VerifyAlgorithm: Algorithm { /// and payload. fn verify( &self, - header: &str, - payload: &str, + header: &[u8], + payload: &[u8], signature: &[u8], ) -> Result; diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index 78e22d2..956f6b1 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -16,7 +16,8 @@ //! This algorithm is used to sign and verify JSON Web Tokens using the RSASSA-PSS. use base64ct::{Base64UrlUnpadded, Encoding}; -use rsa::pkcs1v15::SigningKey; +use bytes::BytesMut; +use rsa::pkcs1v15::{SigningKey, VerifyingKey}; use rsa::rand_core::OsRng; use rsa::signature::RandomizedSigner; use rsa::PublicKeyParts; @@ -69,6 +70,38 @@ where } } +impl super::VerifyAlgorithm for VerifyingKey +where + D: digest::Digest, + VerifyingKey: super::Algorithm, +{ + type Error = signature::Error; + + type Key = rsa::RsaPublicKey; + + fn verify( + &self, + header: &[u8], + payload: &[u8], + signature: &[u8], + ) -> Result { + use rsa::signature::Verifier; + let signature = rsa::pkcs1v15::Signature::try_from(signature).unwrap(); + + let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1); + message.extend_from_slice(header); + message.extend_from_slice(b"."); + message.extend_from_slice(payload); + + >::verify(self, message.as_ref(), &signature)?; + Ok(signature) + } + + fn key(&self) -> &Self::Key { + self.as_ref() + } +} + impl super::Algorithm for RsaPkcs1v15 { const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS256; type Signature = rsa::pkcs1v15::Signature; diff --git a/src/base64data.rs b/src/base64data.rs index d747358..76ed48f 100644 --- a/src/base64data.rs +++ b/src/base64data.rs @@ -120,6 +120,11 @@ where let inner = serde_json::to_vec(&self.0)?; Ok(base64ct::Base64UrlUnpadded::encode_string(&inner)) } + + pub(crate) fn serialized_bytes(&self) -> Result, serde_json::Error> { + let inner = serde_json::to_vec(&self.0)?; + Ok(inner.into_boxed_slice()) + } } impl AsRef for Base64JSON { diff --git a/src/jose/derive.rs b/src/jose/derive.rs index 12ff501..fa1c7bc 100644 --- a/src/jose/derive.rs +++ b/src/jose/derive.rs @@ -106,6 +106,13 @@ where KeyDerivation::Explicit(value) => DerivedKeyValue::Explicit(value), } } + + pub(super) fn explicit(value: Option) -> Self { + match value { + Some(value) => DerivedKeyValue::Explicit(value), + None => DerivedKeyValue::Omit, + } + } } impl ser::Serialize for DerivedKeyValue diff --git a/src/jose/mod.rs b/src/jose/mod.rs index 74b88d8..cf227ba 100644 --- a/src/jose/mod.rs +++ b/src/jose/mod.rs @@ -18,6 +18,7 @@ use url::Url; #[cfg(feature = "fmt")] use crate::base64data::Base64JSON; +use crate::token::TokenVerifyingError; use crate::{algorithms::AlgorithmIdentifier, key::SerializeJWK}; #[cfg(feature = "fmt")] @@ -108,7 +109,7 @@ const REGISTERED_HEADER_KEYS: [&str; 11] = [ #[non_exhaustive] pub struct Header { #[serde(flatten)] - state: State, + pub(crate) state: State, /// The set of registered header parameters from [JWS][] and [JWA][]. /// @@ -207,8 +208,17 @@ where /// Render a signed JWK header into its rendered /// form, where the derived fields have been built /// as necessary. - pub fn render(self) -> Header { + pub fn render(self) -> Header + where + H: Serialize, + SignedHeader: HeaderState, + { + let headers = Base64JSON(&self) + .serialized_bytes() + .expect("valid header value"); + let state = RenderedHeader { + raw: headers, algorithm: self.state.algorithm, key: self.state.key.build(), thumbprint: self.state.thumbprint.build(), @@ -229,11 +239,30 @@ impl Header { } #[allow(unused_variables)] - pub(crate) fn verify(self, key: &A::Key) -> Result>, A::Error> + pub(crate) fn verify( + self, + key: &A::Key, + ) -> Result>, TokenVerifyingError> where A: crate::algorithms::VerifyAlgorithm, { - todo!("verify"); + // This may need to only verify that the algorithm header matches the key algorithm. + if A::IDENTIFIER != self.state.algorithm { + return Err(TokenVerifyingError::Algorithm( + A::IDENTIFIER, + self.state.algorithm, + )); + } + Ok(Header { + state: SignedHeader { + algorithm: self.state.algorithm, + key: DerivedKeyValue::explicit(self.state.key), + thumbprint: DerivedKeyValue::explicit(self.state.thumbprint), + thumbprint_sha256: DerivedKeyValue::explicit(self.state.thumbprint_sha256), + }, + registered: self.registered, + custom: self.custom, + }) } } diff --git a/src/jose/rendered.rs b/src/jose/rendered.rs index 9f177d9..7bdc09b 100644 --- a/src/jose/rendered.rs +++ b/src/jose/rendered.rs @@ -15,6 +15,9 @@ use super::HeaderState; /// and not thd derivation, so the fields may be in inconsistent states. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RenderedHeader { + /// The raw bytes of the header, as it was signed. + pub(crate) raw: Box<[u8]>, + #[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))] #[serde(rename = "alg")] pub(super) algorithm: AlgorithmIdentifier, diff --git a/src/token/mod.rs b/src/token/mod.rs index 8b004ac..c582d8e 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -67,6 +67,13 @@ where Payload::Empty => Ok("".to_owned()), } } + + fn serialized_bytes(&self) -> Result, serde_json::Error> { + match self { + Payload::Json(data) => data.serialized_bytes(), + Payload::Empty => Ok(Box::new([])), + } + } } impl

From

for Payload

{ @@ -364,16 +371,16 @@ where } let signature = &self.state.signature; - let header = self - .state - .header - .verify::(algorithm.key()) - .map_err(TokenVerifyingError::Verify)?; - let headers = Base64JSON(&header).serialized_value()?; - let payload = self.payload.serialized_value()?; let signature = algorithm - .verify(&headers, &payload, signature.as_ref()) + .verify( + &self.state.header.state.raw, + &self.state.payload, + signature.as_ref(), + ) .map_err(TokenVerifyingError::Verify)?; + + let header = self.state.header.verify::(algorithm.key())?; + Ok(Token { payload: self.payload, state: Verified { header, signature }, @@ -386,15 +393,23 @@ impl Token, Fmt> where Fmt: TokenFormat, Alg: SigningAlgorithm, + Alg::Key: Clone, + H: Serialize, + P: Serialize, { /// Transition the token back into an unverified state. /// /// This method consumes the token and returns a new one, which still includes the signature /// but which is no longer considered verified. pub fn unverify(self) -> Token, Fmt> { + let payload = self + .payload + .serialized_bytes() + .expect("valid payload bytes"); Token { payload: self.payload, state: Unverified { + payload, header: self.state.header.render(), signature: Base64Data(self.state.signature.as_ref().to_owned().into()), }, @@ -408,7 +423,7 @@ impl fmt::JWTFormat for Token where S: HasSignature, ::Header: Serialize, - ::HeaderState: Serialize + HeaderState, + ::HeaderState: HeaderState, P: Serialize, Fmt: TokenFormat, { diff --git a/src/token/state.rs b/src/token/state.rs index 9b1db01..6bb1e9c 100644 --- a/src/token/state.rs +++ b/src/token/state.rs @@ -189,6 +189,7 @@ where deserialize = "H: for<'deh> Deserialize<'deh>" ))] pub struct Unverified { + pub(super) payload: Box<[u8]>, pub(super) header: jose::Header, pub(super) signature: Base64Data, }