Skip to content

Commit

Permalink
cleanup and disallow unwraps
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Dec 17, 2024
1 parent d675b13 commit cd99409
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 72 deletions.
50 changes: 27 additions & 23 deletions src/algorithms/rsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub fn rsa_decrypt<R: CryptoRngCore + ?Sized>(
let bits = d.bits_precision();

let c = if let Some(ref mut rng) = rng {
let (blinded, unblinder) = blind(rng, priv_key, c, &n_params);
let (blinded, unblinder) = blind(rng, priv_key, c, n_params);
ir = Some(unblinder);
blinded.widen(bits)
} else {
Expand All @@ -60,15 +60,15 @@ pub fn rsa_decrypt<R: CryptoRngCore + ?Sized>(

let m = if is_multiprime || !has_precomputes {
// c^d (mod n)
pow_mod_params(&c, d, n_params.clone())
pow_mod_params(&c, d, n_params)
} else {
// We have the precalculated values needed for the CRT.

let dp = priv_key.dp().unwrap();
let dq = priv_key.dq().unwrap();
let qinv = priv_key.qinv().unwrap();
let p_params = priv_key.p_params().unwrap();
let q_params = priv_key.q_params().unwrap();
let dp = priv_key.dp().expect("precomputed");
let dq = priv_key.dq().expect("precomputed");
let qinv = priv_key.qinv().expect("precomputed");
let p_params = priv_key.p_params().expect("precomputed");
let q_params = priv_key.q_params().expect("precomputed");

let _p = &priv_key.primes()[0];
let q = &priv_key.primes()[1];
Expand Down Expand Up @@ -166,23 +166,23 @@ fn blind<R: CryptoRngCore, K: PublicKeyParts>(

let blinded = {
// r^e (mod n)
let mut rpowe = pow_mod_params(&r, key.e(), n_params.clone());
let mut rpowe = pow_mod_params(&r, key.e(), n_params);
// c * r^e (mod n)
let c = mul_mod_params(c, &rpowe, n_params.clone());
let c = mul_mod_params(c, &rpowe, n_params);
rpowe.zeroize();

c
};

let ir = ir.unwrap();
let ir = ir.expect("loop exited");
debug_assert_eq!(blinded.bits_precision(), bits);
debug_assert_eq!(ir.bits_precision(), bits);

(blinded, ir)
}

/// Given an m and and unblinding factor, unblind the m.
fn unblind(m: &BoxedUint, unblinder: &BoxedUint, n_params: BoxedMontyParams) -> BoxedUint {
fn unblind(m: &BoxedUint, unblinder: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
// m * r^-1 (mod n)
debug_assert_eq!(
m.bits_precision(),
Expand All @@ -200,16 +200,16 @@ fn unblind(m: &BoxedUint, unblinder: &BoxedUint, n_params: BoxedMontyParams) ->
}

/// Computes `base.pow_mod(exp, n)` with precomputed `n_params`.
fn pow_mod_params(base: &BoxedUint, exp: &BoxedUint, n_params: BoxedMontyParams) -> BoxedUint {
fn pow_mod_params(base: &BoxedUint, exp: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
let base = reduce(base, n_params);
base.pow(exp).retrieve()
}

/// Computes `lhs.mul_mod(rhs, n)` with precomputed `n_params`.
fn mul_mod_params(lhs: &BoxedUint, rhs: &BoxedUint, n_params: BoxedMontyParams) -> BoxedUint {
fn mul_mod_params(lhs: &BoxedUint, rhs: &BoxedUint, n_params: &BoxedMontyParams) -> BoxedUint {
// TODO: nicer api in crypto-bigint?
let lhs = BoxedMontyForm::new(lhs.clone(), n_params.clone());
let rhs = BoxedMontyForm::new(rhs.clone(), n_params);
let rhs = BoxedMontyForm::new(rhs.clone(), n_params.clone());
(lhs * rhs).retrieve()
}

Expand Down Expand Up @@ -247,11 +247,11 @@ pub fn recover_primes(

// 3. Let b = ( (n – r)/(m + 1) ) + 1; if b is not an integer or b^2 ≤ 4n, then output an error indicator,
// and exit without further processing.
let modulus_check = (&n - &r) % NonZero::new(&m + &one).unwrap();
let modulus_check = (&n - &r) % NonZero::new(&m + &one).expect("adding 1");
if (!modulus_check.is_zero()).into() {
return Err(Error::InvalidArguments);
}
let b = ((&n - &r) / NonZero::new(&m + &one).unwrap()) + one;
let b = ((&n - &r) / NonZero::new(&m + &one).expect("adding one")) + one;

let four = BoxedUint::from(4u32);
let four_n = &n * four;
Expand All @@ -273,7 +273,9 @@ pub fn recover_primes(
}

let bits = core::cmp::max(b.bits_precision(), y.bits_precision());
let two = NonZero::new(BoxedUint::from(2u64)).unwrap().widen(bits);
let two = NonZero::new(BoxedUint::from(2u64))
.expect("2 is non zero")
.widen(bits);
let p = (&b + &y) / &two;
let q = (b - y) / two;

Expand All @@ -282,11 +284,12 @@ pub fn recover_primes(

/// Compute the modulus of a key from its primes.
pub(crate) fn compute_modulus(primes: &[BoxedUint]) -> Odd<BoxedUint> {
let mut out = primes[0].clone();
for p in &primes[1..] {
let mut primes = primes.iter();
let mut out = primes.next().expect("must at least be one prime").clone();
for p in primes {
out = out * p;
}
Odd::new(out).unwrap()
Odd::new(out).expect("modulus must be odd")
}

/// Compute the private exponent from its primes (p and q) and public exponent
Expand Down Expand Up @@ -329,12 +332,13 @@ pub(crate) fn compute_private_exponent_carmicheal(
q: &BoxedUint,
exp: &BoxedUint,
) -> Result<BoxedUint> {
let p1 = p - &BoxedUint::one();
let q1 = q - &BoxedUint::one();
let one = BoxedUint::one();
let p1 = p - &one;
let q1 = q - &one;

// LCM inlined
let gcd = p1.gcd(&q1);
let lcm = p1 / NonZero::new(gcd).unwrap() * &q1;
let lcm = p1 / NonZero::new(gcd).expect("gcd is non zero") * &q1;
let exp = exp.widen(lcm.bits_precision());
if let Some(d) = exp.inv_mod(&lcm).into() {
Ok(d)
Expand Down
97 changes: 57 additions & 40 deletions src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ pub struct RsaPublicKey {
/// Public exponent: power to which a plaintext message is raised in
/// order to encrypt it.
///
/// Typically 0x10001 (65537)
/// Typically `0x10001` (`65537`)
e: BoxedUint,

n_params: BoxedMontyParams,
}

impl Eq for RsaPublicKey {}

impl PartialEq for RsaPublicKey {
#[inline]
fn eq(&self, other: &RsaPublicKey) -> bool {
Expand All @@ -63,7 +64,7 @@ pub struct RsaPrivateKey {
pub(crate) d: BoxedUint,
/// Prime factors of N, contains >= 2 elements.
pub(crate) primes: Vec<BoxedUint>,
/// precomputed values to speed up private operations
/// Precomputed values to speed up private operations
pub(crate) precomputed: Option<PrecomputedValues>,
}

Expand Down Expand Up @@ -110,14 +111,21 @@ pub(crate) struct PrecomputedValues {
/// Q^-1 mod P
pub(crate) qinv: BoxedMontyForm,

/// Montgomery params for `p`
pub(crate) p_params: BoxedMontyParams,
/// Montgomery params for `q`
pub(crate) q_params: BoxedMontyParams,
}

impl ZeroizeOnDrop for PrecomputedValues {}

impl Zeroize for PrecomputedValues {
fn zeroize(&mut self) {
self.dp.zeroize();
self.dq.zeroize();
// TODO: once these have landed in crypto-bigint
// self.p_params.zeroize();
// self.q_params.zeroize();
}
}

Expand All @@ -141,7 +149,7 @@ impl From<&RsaPrivateKey> for RsaPublicKey {
RsaPublicKey {
n: n.clone(),
e: e.clone(),
n_params,
n_params: n_params.clone(),
}
}
}
Expand All @@ -155,8 +163,8 @@ impl PublicKeyParts for RsaPublicKey {
&self.e
}

fn n_params(&self) -> BoxedMontyParams {
self.n_params.clone()
fn n_params(&self) -> &BoxedMontyParams {
&self.n_params
}
}

Expand Down Expand Up @@ -204,7 +212,9 @@ impl RsaPublicKey {
pub fn new_with_max_size(n: BoxedUint, e: BoxedUint, max_size: usize) -> Result<Self> {
check_public_with_max_size(&n, &e, max_size)?;

let n_odd = Odd::new(n.clone()).unwrap();
let n_odd = Odd::new(n.clone())
.into_option()
.ok_or(Error::InvalidModulus)?;
let n_params = BoxedMontyParams::new(n_odd);
let n = NonZero::new(n).expect("checked above");

Expand All @@ -218,9 +228,9 @@ impl RsaPublicKey {
/// Most applications should use [`RsaPublicKey::new`] or
/// [`RsaPublicKey::new_with_max_size`] instead.
pub fn new_unchecked(n: BoxedUint, e: BoxedUint) -> Self {
let n_odd = Odd::new(n.clone()).unwrap();
let n_odd = Odd::new(n.clone()).expect("n must be odd");
let n_params = BoxedMontyParams::new(n_odd);
let n = NonZero::new(n).unwrap();
let n = NonZero::new(n).expect("odd numbers are non zero");

Self { n, e, n_params }
}
Expand All @@ -235,8 +245,8 @@ impl PublicKeyParts for RsaPrivateKey {
&self.pubkey_components.e
}

fn n_params(&self) -> BoxedMontyParams {
self.pubkey_components.n_params.clone()
fn n_params(&self) -> &BoxedMontyParams {
&self.pubkey_components.n_params
}
}

Expand Down Expand Up @@ -282,17 +292,20 @@ impl RsaPrivateKey {
mut primes: Vec<BoxedUint>,
) -> Result<RsaPrivateKey> {
let n_params = BoxedMontyParams::new(n.clone());
let n_c = NonZero::new(n.as_ref().clone()).unwrap();

if primes.len() < 2 {
if !primes.is_empty() {
return Err(Error::NprimesTooSmall);
let n_c = NonZero::new(n.get())
.into_option()
.ok_or(Error::InvalidModulus)?;

match primes.len() {
0 => {
// Recover `p` and `q` from `d`.
// See method in Appendix C.2: https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf
let (p, q) = recover_primes(&n_c, &e, &d)?;
primes.push(p);
primes.push(q);
}
// Recover `p` and `q` from `d`.
// See method in Appendix C.2: https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf
let (p, q) = recover_primes(&n_c, &e, &d)?;
primes.push(p);
primes.push(q);
1 => return Err(Error::NprimesTooSmall),
_ => {}
}

let mut k = RsaPrivateKey {
Expand All @@ -309,8 +322,8 @@ impl RsaPrivateKey {
// Alaways validate the key, to ensure precompute can't fail
k.validate()?;

// precompute when possible, ignore error otherwise.
let _ = k.precompute();
// Precompute when possible, ignore error otherwise.
k.precompute().ok();

Ok(k)
}
Expand All @@ -330,10 +343,11 @@ impl RsaPrivateKey {
return Err(Error::InvalidPrime);
}

let n = compute_modulus(&[p.clone(), q.clone()]);
let d = compute_private_exponent_carmicheal(&p, &q, &public_exponent)?;
let primes = vec![p, q];
let n = compute_modulus(&primes);

Self::from_components(n, public_exponent, d, vec![p, q])
Self::from_components(n, public_exponent, d, primes)
}

/// Constructs an RSA key pair from its primes.
Expand All @@ -347,7 +361,7 @@ impl RsaPrivateKey {
return Err(Error::NprimesTooSmall);
}

// Makes sure that primes is pairwise unequal.
// Makes sure that the primes are pairwise unequal.
for (i, prime1) in primes.iter().enumerate() {
for prime2 in primes.iter().take(i) {
if prime1 == prime2 {
Expand Down Expand Up @@ -381,25 +395,27 @@ impl RsaPrivateKey {
let p = self.primes[0].widen(bits);
let q = self.primes[1].widen(bits);

// TODO: error handling

let p_odd = Odd::new(p.clone()).unwrap();
let p_odd = Odd::new(p.clone())
.into_option()
.ok_or(Error::InvalidPrime)?;
let p_params = BoxedMontyParams::new(p_odd);
let q_odd = Odd::new(q.clone()).unwrap();
let q_odd = Odd::new(q.clone())
.into_option()
.ok_or(Error::InvalidPrime)?;
let q_params = BoxedMontyParams::new(q_odd);

let x = NonZero::new(p.wrapping_sub(&BoxedUint::one())).unwrap();
let x = NonZero::new(p.wrapping_sub(&BoxedUint::one()))
.into_option()
.ok_or(Error::InvalidPrime)?;
let dp = d.rem_vartime(&x);

let x = NonZero::new(q.wrapping_sub(&BoxedUint::one())).unwrap();
let x = NonZero::new(q.wrapping_sub(&BoxedUint::one()))
.into_option()
.ok_or(Error::InvalidPrime)?;
let dq = d.rem_vartime(&x);

let qinv = BoxedMontyForm::new(q.clone(), p_params.clone());
let qinv = qinv.invert();
if qinv.is_none().into() {
return Err(Error::InvalidPrime);
}
let qinv = qinv.unwrap();
let qinv = qinv.invert().into_option().ok_or(Error::InvalidPrime)?;

debug_assert_eq!(dp.bits_precision(), bits);
debug_assert_eq!(dq.bits_precision(), bits);
Expand Down Expand Up @@ -438,9 +454,10 @@ impl RsaPrivateKey {

// Check that Πprimes == n.
let mut m = BoxedUint::one_with_precision(self.pubkey_components.n.bits_precision());
let one = BoxedUint::one();
for prime in &self.primes {
// Any primes ≤ 1 will cause divide-by-zero panics later.
if prime < &BoxedUint::one() {
if prime < &one {
return Err(Error::InvalidPrime);
}
m = m.wrapping_mul(prime);
Expand Down Expand Up @@ -577,9 +594,9 @@ fn check_public_with_max_size(n: &BoxedUint, e: &BoxedUint, max_size: usize) ->
Ok(())
}

pub(crate) fn reduce(n: &BoxedUint, p: BoxedMontyParams) -> BoxedMontyForm {
pub(crate) fn reduce(n: &BoxedUint, p: &BoxedMontyParams) -> BoxedMontyForm {
let bits_precision = p.modulus().bits_precision();
let modulus = NonZero::new(p.modulus().as_ref().clone()).unwrap();
let modulus = p.modulus().as_nz_ref().clone();

let n = match n.bits_precision().cmp(&bits_precision) {
Ordering::Less => n.widen(bits_precision),
Expand All @@ -588,7 +605,7 @@ pub(crate) fn reduce(n: &BoxedUint, p: BoxedMontyParams) -> BoxedMontyForm {
};

let n_reduced = n.rem_vartime(&modulus).widen(p.bits_precision());
BoxedMontyForm::new(n_reduced, p)
BoxedMontyForm::new(n_reduced, p.clone())
}

#[cfg(feature = "serde")]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#![doc = include_str!("../README.md")]
#![doc(html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo_small.png")]
#![warn(missing_docs)]
#![cfg_attr(not(test), deny(clippy::unwrap_used))]

//! # Supported algorithms
//!
Expand Down
8 changes: 4 additions & 4 deletions src/pkcs1v15/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ impl TryFrom<&[u8]> for Signature {

fn try_from(bytes: &[u8]) -> signature::Result<Self> {
let len = bytes.len();
Ok(Self {
// TODO: how to convert error?
inner: BoxedUint::from_be_slice(bytes, len as u32 * 8).unwrap(),
})
let inner = BoxedUint::from_be_slice(bytes, len as u32 * 8)
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync + 'static>)?;

Ok(Self { inner })
}
}

Expand Down
Loading

0 comments on commit cd99409

Please sign in to comment.