diff --git a/src/symmetric_crypto/nonce.rs b/src/symmetric_crypto/nonce.rs index b785771..d364cca 100644 --- a/src/symmetric_crypto/nonce.rs +++ b/src/symmetric_crypto/nonce.rs @@ -8,7 +8,6 @@ use core::{ convert::{TryFrom, TryInto}, fmt::{Debug, Display}, }; -use num_bigint::BigUint; use rand_core::{CryptoRng, RngCore}; /// Trait defining a nonce for use in a symmetric encryption scheme. @@ -26,7 +25,7 @@ pub trait NonceTrait: Send + Sync + Sized + Clone { /// Increment the nonce by the given value. #[must_use] - fn increment(&self, increment: usize) -> Self; + fn increment(&self, increment: u64) -> Self; /// Xor the nonce with the given value. #[must_use] @@ -62,12 +61,26 @@ impl NonceTrait for Nonce { } #[inline] - fn increment(&self, increment: usize) -> Self { - let mut bi = BigUint::from_bytes_le(&self.0); - bi += BigUint::from(increment); - let mut bi_bytes = bi.to_bytes_le(); - bi_bytes.resize(NONCE_LENGTH, 0); - Self(bi_bytes.try_into().expect("This should never happen")) + fn increment(&self, increment: u64) -> Self { + let increment = increment.to_le_bytes(); + assert!(NONCE_LENGTH > 8, "Consider using a longer Nonce!"); + // add the first bytes + let mut res = [0; NONCE_LENGTH]; + let mut carry = 0; + for (i, (b1, b2)) in self.0.iter().zip(increment).enumerate() { + (res[i], carry) = adc(*b1, b2, carry); + } + // take into account the potentially remaining carry + res[increment.len()] = self.0[8] + carry; + // copy the rest of the input nonce + for (res, b) in res + .iter_mut() + .rev() + .zip(self.0.iter().rev().take(NONCE_LENGTH - 7)) + { + *res = *b; + } + Self(res) } #[inline] @@ -85,6 +98,12 @@ impl NonceTrait for Nonce { } } +#[inline] +const fn adc(a: u8, b: u8, carry: u8) -> (u8, u8) { + let ret = (a as u16) + (b as u16) + (carry as u16); + (ret as u8, (ret >> 8) as u8) +} + impl<'a, const NONCE_LENGTH: usize> TryFrom<&'a [u8]> for Nonce { type Error = CryptoCoreError; @@ -100,7 +119,7 @@ impl From<[u8; NONCE_LENGTH]> for Nonce } impl Display for Nonce { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", hex::encode(self.0)) } } @@ -128,7 +147,7 @@ mod tests { fn test_increment_nonce() { const NONCE_LENGTH: usize = 12; let mut nonce: Nonce = Nonce::from([0_u8; NONCE_LENGTH]); - let inc = 1_usize << 10; + let inc = 1 << 10; nonce = nonce.increment(inc); println!("{}", hex::encode(nonce.0)); assert_eq!("000400000000000000000000", hex::encode(nonce.0));