Skip to content

Commit

Permalink
Signature verification support
Browse files Browse the repository at this point in the history
Adds signature verification for tokens and associated header and algorithm methods.
  • Loading branch information
alexrudy committed Nov 24, 2023
1 parent 2d4c6c1 commit 8fc2acf
Show file tree
Hide file tree
Showing 11 changed files with 199 additions and 24 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ categories = [

[dependencies]
base64ct = { version = "1.6.0", features = ["std"] }
bytes = { version = "1.5" }
chrono = { version = "0.4.24", features = ["serde"] }
elliptic-curve = { version = "0.13.4", features = [
"pkcs8",
Expand Down
68 changes: 66 additions & 2 deletions src/algorithms/ecdsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ use ::ecdsa::{
PrimeCurve, SignatureSize,
};
use base64ct::Encoding;
use bytes::BytesMut;
use digest::generic_array::ArrayLength;
use ecdsa::{hazmat::VerifyPrimitive, Signature, VerifyingKey};
use elliptic_curve::{
ops::Invert,
sec1::{Coordinates, FromEncodedPoint, ModulusSize, ToEncodedPoint},
Expand All @@ -92,6 +94,7 @@ pub use p384::NistP384;

#[cfg(feature = "p521")]
pub use p521::NistP521;
use signature::Verifier;

impl<C> crate::key::JWKeyType for PublicKey<C>
where
Expand All @@ -109,7 +112,7 @@ where
FieldBytesSize<C>: ModulusSize,
{
fn parameters(&self) -> Vec<(String, serde_json::Value)> {
let mut params = Vec::with_capacity(2);
let mut params = Vec::with_capacity(3);

params.push((
"crv".to_owned(),
Expand Down Expand Up @@ -144,6 +147,31 @@ where
self.public_key().parameters()
}
}

impl<C> crate::key::SerializeJWK for ecdsa::VerifyingKey<C>
where
C: PrimeCurve + CurveArithmetic + JwkParameters,
Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
SignatureSize<C>: ArrayLength<u8>,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldBytesSize<C>: ModulusSize,
{
fn parameters(&self) -> Vec<(String, serde_json::Value)> {
PublicKey::<C>::from(self).parameters()
}
}

impl<C> crate::key::JWKeyType for ecdsa::VerifyingKey<C>
where
C: PrimeCurve + CurveArithmetic + JwkParameters,
Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
SignatureSize<C>: ArrayLength<u8>,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldBytesSize<C>: ModulusSize,
{
const KEY_TYPE: &'static str = "EC";
}

impl<C> crate::key::SerializeJWK for ecdsa::SigningKey<C>
where
C: PrimeCurve + CurveArithmetic + JwkParameters,
Expand All @@ -153,7 +181,7 @@ where
FieldBytesSize<C>: ModulusSize,
{
fn parameters(&self) -> Vec<(String, serde_json::Value)> {
todo!()
self.verifying_key().parameters()
}
}

Expand Down Expand Up @@ -223,6 +251,42 @@ where
}
}

impl<C> super::VerifyAlgorithm for VerifyingKey<C>
where
C: PrimeCurve + CurveArithmetic + JwkParameters + ecdsa::hazmat::DigestPrimitive,
<C as CurveArithmetic>::AffinePoint: VerifyPrimitive<C>,
Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
SignatureSize<C>: ArrayLength<u8>,
MaxSize<C>: ArrayLength<u8>,
AffinePoint<C>: FromEncodedPoint<C> + ToEncodedPoint<C>,
FieldBytesSize<C>: ModulusSize,
VerifyingKey<C>: super::Algorithm<Signature = ecdsa::SignatureBytes<C>>,
<FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
type Error = ecdsa::Error;
type Key = VerifyingKey<C>;

fn verify(
&self,
header: &[u8],
payload: &[u8],
signature: &[u8],
) -> Result<Self::Signature, Self::Error> {
let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1);
message.extend_from_slice(header);
message.extend_from_slice(b".");
message.extend_from_slice(payload);

let signature = ecdsa::Signature::try_from(signature)?;
<Self as Verifier<Signature<C>>>::verify(self, message.as_ref(), &signature)?;
Ok(signature.into())
}

fn key(&self) -> &Self::Key {
self
}
}

#[cfg(all(test, feature = "p256"))]
mod test {

Expand Down
14 changes: 8 additions & 6 deletions src/algorithms/hmac.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!
//! Based on the [hmac](https://crates.io/crates/hmac) crate.
use std::convert::Infallible;
use std::{convert::Infallible, marker::PhantomData};

use base64ct::Encoding;
use digest::Mac;
Expand Down Expand Up @@ -81,7 +81,7 @@ where
D: digest::Digest + digest::core_api::BlockSizeUser,
{
key: HmacKey,
digest: hmac::SimpleHmac<D>,
_digest: PhantomData<D>,
}

impl<D> Hmac<D>
Expand All @@ -93,8 +93,10 @@ where
///
/// Signing keys are arbitrary bytes.
pub fn new(key: HmacKey) -> Self {
let digest = SimpleHmac::new_from_slice(key.as_ref()).expect("Valid key");
Self { key, digest }
Self {
key,
_digest: PhantomData,
}
}

/// Reference to the HMAC key.
Expand Down Expand Up @@ -133,8 +135,8 @@ where

fn sign(&self, header: &str, payload: &str) -> Result<Self::Signature, Self::Error> {
// Create a new, one-shot digest for this signature.
let mut digest = self.digest.clone();
digest.reset();
let mut digest: SimpleHmac<D> =
SimpleHmac::new_from_slice(self.key.as_ref()).expect("Valid key");
let message = format!("{}.{}", header, payload);
digest.update(message.as_bytes());
Ok(digest.finalize().into_bytes())
Expand Down
19 changes: 17 additions & 2 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ pub enum AlgorithmIdentifier {
None,
}

impl AlgorithmIdentifier {
/// Return whether this algorithm is available for signing.
pub fn available(&self) -> bool {
match self {
Self::None => true,

Self::HS256 | Self::HS384 | Self::HS512 => cfg!(feature = "hmac"),
Self::RS256 | Self::RS384 | Self::RS512 => cfg!(feature = "rsa"),
Self::ES256 | Self::ES384 | Self::ES512 => cfg!(feature = "ecdsa"),
Self::EdDSA => cfg!(feature = "ed25519"),
Self::PS256 | Self::PS384 | Self::PS512 => cfg!(feature = "rsa"),
}
}
}

/// A trait to associate an alogritm identifier with an algorithm.
///
/// Algorithm identifiers are used in JWS and JWE to indicate how a token is signed or encrypted.
Expand Down Expand Up @@ -118,8 +133,8 @@ pub trait VerifyAlgorithm: Algorithm {
/// and payload.
fn verify(
&self,
header: &str,
payload: &str,
header: &[u8],
payload: &[u8],
signature: &[u8],
) -> Result<Self::Signature, Self::Error>;

Expand Down
35 changes: 34 additions & 1 deletion src/algorithms/rsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
//! This algorithm is used to sign and verify JSON Web Tokens using the RSASSA-PSS.
use base64ct::{Base64UrlUnpadded, Encoding};
use rsa::pkcs1v15::SigningKey;
use bytes::BytesMut;
use rsa::pkcs1v15::{SigningKey, VerifyingKey};
use rsa::rand_core::OsRng;
use rsa::signature::RandomizedSigner;
use rsa::PublicKeyParts;
Expand Down Expand Up @@ -69,6 +70,38 @@ where
}
}

impl<D> super::VerifyAlgorithm for VerifyingKey<D>
where
D: digest::Digest,
VerifyingKey<D>: super::Algorithm<Signature = rsa::pkcs1v15::Signature>,
{
type Error = signature::Error;

type Key = rsa::RsaPublicKey;

fn verify(
&self,
header: &[u8],
payload: &[u8],
signature: &[u8],
) -> Result<Self::Signature, Self::Error> {
use rsa::signature::Verifier;
let signature = rsa::pkcs1v15::Signature::try_from(signature).unwrap();

let mut message = BytesMut::with_capacity(header.len() + payload.len() + 1);
message.extend_from_slice(header);
message.extend_from_slice(b".");
message.extend_from_slice(payload);

<Self as Verifier<rsa::pkcs1v15::Signature>>::verify(self, message.as_ref(), &signature)?;
Ok(signature)
}

fn key(&self) -> &Self::Key {
self.as_ref()
}
}

impl super::Algorithm for RsaPkcs1v15<sha2::Sha256> {
const IDENTIFIER: super::AlgorithmIdentifier = super::AlgorithmIdentifier::RS256;
type Signature = rsa::pkcs1v15::Signature;
Expand Down
5 changes: 5 additions & 0 deletions src/base64data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ where
let inner = serde_json::to_vec(&self.0)?;
Ok(base64ct::Base64UrlUnpadded::encode_string(&inner))
}

pub(crate) fn serialized_bytes(&self) -> Result<Box<[u8]>, serde_json::Error> {
let inner = serde_json::to_vec(&self.0)?;
Ok(inner.into_boxed_slice())
}
}

impl<T> AsRef<T> for Base64JSON<T> {
Expand Down
7 changes: 7 additions & 0 deletions src/jose/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ where
KeyDerivation::Explicit(value) => DerivedKeyValue::Explicit(value),
}
}

pub(super) fn explicit(value: Option<Builder::Value>) -> Self {
match value {
Some(value) => DerivedKeyValue::Explicit(value),
None => DerivedKeyValue::Omit,
}
}
}

impl<Builder, Key> ser::Serialize for DerivedKeyValue<Builder, Key>
Expand Down
37 changes: 33 additions & 4 deletions src/jose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use url::Url;

#[cfg(feature = "fmt")]
use crate::base64data::Base64JSON;
use crate::token::TokenVerifyingError;
use crate::{algorithms::AlgorithmIdentifier, key::SerializeJWK};

#[cfg(feature = "fmt")]
Expand Down Expand Up @@ -108,7 +109,7 @@ const REGISTERED_HEADER_KEYS: [&str; 11] = [
#[non_exhaustive]
pub struct Header<H, State> {
#[serde(flatten)]
state: State,
pub(crate) state: State,

/// The set of registered header parameters from [JWS][] and [JWA][].
///
Expand Down Expand Up @@ -207,8 +208,17 @@ where
/// Render a signed JWK header into its rendered
/// form, where the derived fields have been built
/// as necessary.
pub fn render(self) -> Header<H, RenderedHeader> {
pub fn render(self) -> Header<H, RenderedHeader>
where
H: Serialize,
SignedHeader<Key>: HeaderState,
{
let headers = Base64JSON(&self)
.serialized_bytes()
.expect("valid header value");

let state = RenderedHeader {
raw: headers,
algorithm: self.state.algorithm,
key: self.state.key.build(),
thumbprint: self.state.thumbprint.build(),
Expand All @@ -229,11 +239,30 @@ impl<H> Header<H, RenderedHeader> {
}

#[allow(unused_variables)]
pub(crate) fn verify<A>(self, key: &A::Key) -> Result<Header<H, SignedHeader<A::Key>>, A::Error>
pub(crate) fn verify<A>(
self,
key: &A::Key,
) -> Result<Header<H, SignedHeader<A::Key>>, TokenVerifyingError<A::Error>>

Check failure on line 245 in src/jose/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

very complex type used. Consider factoring parts into `type` definitions

error: very complex type used. Consider factoring parts into `type` definitions --> src/jose/mod.rs:245:10 | 245 | ) -> Result<Header<H, SignedHeader<A::Key>>, TokenVerifyingError<A::Error>> | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#type_complexity = note: `-D clippy::type-complexity` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(clippy::type_complexity)]`
where
A: crate::algorithms::VerifyAlgorithm,
{
todo!("verify");
// This may need to only verify that the algorithm header matches the key algorithm.
if A::IDENTIFIER != self.state.algorithm {
return Err(TokenVerifyingError::Algorithm(
A::IDENTIFIER,
self.state.algorithm,
));
}
Ok(Header {
state: SignedHeader {
algorithm: self.state.algorithm,
key: DerivedKeyValue::explicit(self.state.key),
thumbprint: DerivedKeyValue::explicit(self.state.thumbprint),
thumbprint_sha256: DerivedKeyValue::explicit(self.state.thumbprint_sha256),
},
registered: self.registered,
custom: self.custom,
})
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/jose/rendered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ use super::HeaderState;
/// and not thd derivation, so the fields may be in inconsistent states.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RenderedHeader {
/// The raw bytes of the header, as it was signed.
pub(crate) raw: Box<[u8]>,

#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/docs/jose/algorithm.md"))]
#[serde(rename = "alg")]
pub(super) algorithm: AlgorithmIdentifier,
Expand Down
Loading

0 comments on commit 8fc2acf

Please sign in to comment.