From 4ee74b40c12b6e9459c6b7e5d5a60fe32818b4ad Mon Sep 17 00:00:00 2001 From: Nils Hasenbanck Date: Fri, 9 Aug 2024 18:41:57 +0200 Subject: [PATCH] Make the TLS PRNG selectable via feature. --- .github/workflows/ci.yml | 10 +++++----- Cargo.toml | 8 +++++++- src/backend/aarch64.rs | 33 ++++++++++++++++++++++++++++++++- src/backend/riscv64.rs | 33 ++++++++++++++++++++++++++++++++- src/backend/soft.rs | 39 ++++++++++++++++++++++++++++++++++++++- src/backend/x86.rs | 33 ++++++++++++++++++++++++++++++++- src/tls.rs | 39 +++++++++++++++++++++++++++++++++++---- tests/tls.rs | 18 +++++++++--------- 8 files changed, 190 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f566e57..5c20a70 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -102,31 +102,31 @@ jobs: shell: bash run: | set -e - cargo test --lib + cargo test --lib --tests - name: Tests (force_software) shell: bash run: | set -e - cargo test --lib --features=force_software + cargo test --lib --tests --features=force_software - name: Tests (force_runtime_detection) shell: bash run: | set -e - cargo test --lib --features=force_runtime_detection + cargo test --lib --tests --features=force_runtime_detection - name: Tests no-std shell: bash run: | set -e - cargo test --lib --no-default-features + cargo test --lib --tests --no-default-features - name: Tests no-std (force_software) shell: bash run: | set -e - cargo test --lib --no-default-features --features=force_software + cargo test --lib --tests --no-default-features --features=force_software verification: timeout-minutes: 30 diff --git a/Cargo.toml b/Cargo.toml index 4d409a4..55466b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,14 @@ rust-version = "1.80" default = ["std", "tls", "getrandom", "rand_core"] # Used for TLS and runtime target feature detection. std = [] -# Activates the thread local functionality. +# Activates the thread local functionality (defaults to the AES-128, 64-bit counter version). tls = ["std"] +# Uses the AES-128, 128-bit counter version for the TLS instance. +tls_aes128_ctr128 = [] +# Uses the AES-256, 64-bit counter version for the TLS instance. +tls_aes256_ctr64 = [] +# Uses the AES-256, 128-bit counter version for the TLS instance. +tls_aes256_ctr128 = [] # Enables support for experimental RISC-V vector cryptography extension. Please read the README.md. experimental_riscv = [] diff --git a/src/backend/aarch64.rs b/src/backend/aarch64.rs index b14e00f..600104f 100644 --- a/src/backend/aarch64.rs +++ b/src/backend/aarch64.rs @@ -29,7 +29,14 @@ impl Drop for Aes128Ctr64 { } impl Aes128Ctr64 { - #[cfg(feature = "tls")] + #[cfg(all( + feature = "tls", + not(any( + feature = "tls_aes128_ctr128", + feature = "tls_aes256_ctr64", + feature = "tls_aes256_ctr128" + )) + ))] pub(crate) const fn zeroed() -> Self { Self { counter: Cell::new(unsafe { core::mem::zeroed() }), @@ -129,6 +136,14 @@ impl Drop for Aes128Ctr128 { } impl Aes128Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes128_ctr128"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(0), + round_keys: Cell::new(unsafe { core::mem::zeroed() }), + } + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); self.counter.set(self.counter.get() + (1 << 64)); @@ -222,6 +237,14 @@ impl Drop for Aes256Ctr64 { } impl Aes256Ctr64 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr64"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(unsafe { core::mem::zeroed() }), + round_keys: Cell::new(unsafe { core::mem::zeroed() }), + } + } + #[cfg_attr(not(target_feature = "aes"), target_feature(enable = "aes"))] #[cfg_attr(not(target_feature = "neon"), target_feature(enable = "neon"))] pub(crate) unsafe fn from_seed_impl(key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) -> Self { @@ -316,6 +339,14 @@ impl Drop for Aes256Ctr128 { } impl Aes256Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr128"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(unsafe { core::mem::zeroed() }), + round_keys: Cell::new(unsafe { core::mem::zeroed() }), + } + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); self.counter.set(self.counter.get() + (1 << 64)); diff --git a/src/backend/riscv64.rs b/src/backend/riscv64.rs index ad62893..3339180 100644 --- a/src/backend/riscv64.rs +++ b/src/backend/riscv64.rs @@ -21,7 +21,14 @@ impl Drop for Aes128Ctr64 { } impl Aes128Ctr64 { - #[cfg(feature = "tls")] + #[cfg(all( + feature = "tls", + not(any( + feature = "tls_aes128_ctr128", + feature = "tls_aes256_ctr64", + feature = "tls_aes256_ctr128" + )) + ))] pub(crate) const fn zeroed() -> Self { Self { counter: Cell::new([0; 2]), @@ -157,6 +164,14 @@ impl Drop for Aes128Ctr128 { } impl Aes128Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes128_ctr128"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(0), + round_keys: Cell::new([0; AES128_KEY_COUNT]), + } + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); self.counter.set(self.counter.get() + (1 << 64)); @@ -295,6 +310,14 @@ impl Drop for Aes256Ctr64 { } impl Aes256Ctr64 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr64"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new([0, 0]), + round_keys: Cell::new([0; AES256_KEY_COUNT]), + } + } + pub(crate) unsafe fn from_seed_impl(key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) -> Self { let mut key_0 = [0u8; 16]; let mut key_1 = [0u8; 16]; @@ -445,6 +468,14 @@ impl Drop for Aes256Ctr128 { } impl Aes256Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr128"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(0), + round_keys: Cell::new([0; AES256_KEY_COUNT]), + } + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); self.counter.set(self.counter.get() + (1 << 64)); diff --git a/src/backend/soft.rs b/src/backend/soft.rs index 63efcf7..167d145 100644 --- a/src/backend/soft.rs +++ b/src/backend/soft.rs @@ -67,7 +67,14 @@ impl Drop for Aes128Ctr64 { } impl Aes128Ctr64 { - #[cfg(feature = "tls")] + #[cfg(all( + feature = "tls", + not(any( + feature = "tls_aes128_ctr128", + feature = "tls_aes256_ctr64", + feature = "tls_aes256_ctr128" + )) + ))] pub(crate) const fn zeroed() -> Self { Self(RefCell::new(Aes128Ctr64Inner { counter: [0; 2], @@ -164,6 +171,16 @@ impl Drop for Aes128Ctr128 { } impl Aes128Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes128_ctr128"))] + pub(crate) fn zeroed() -> Self { + Self(RefCell::new(Aes128Ctr128Inner { + counter: 0, + round_keys: [0; FIX_SLICE_128_KEYS_SIZE], + batch_blocks: [Block::default(); BLOCK_COUNT], + batch_num: 0, + })) + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); let mut inner = self.0.borrow_mut(); @@ -260,6 +277,16 @@ impl Drop for Aes256Ctr64 { } impl Aes256Ctr64 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr64"))] + pub(crate) fn zeroed() -> Self { + Self(RefCell::new(Aes256Ctr64Inner { + counter: [0, 0], + round_keys: [0; FIX_SLICE_256_KEYS_SIZE], + batch_blocks: [Block::default(); BLOCK_COUNT], + batch_num: 0, + })) + } + pub(crate) fn from_seed_impl(key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) -> Self { let counter = [u64::from_le_bytes(counter), u64::from_le_bytes(nonce)]; let round_keys: FixsliceKeys256 = aes256_key_expansion(key); @@ -346,6 +373,16 @@ impl Drop for Aes256Ctr128 { } impl Aes256Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr128"))] + pub(crate) fn zeroed() -> Self { + Self(RefCell::new(Aes256Ctr128Inner { + counter: 0, + round_keys: [0; FIX_SLICE_256_KEYS_SIZE], + batch_blocks: [Block::default(); BLOCK_COUNT], + batch_num: 0, + })) + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); let mut inner = self.0.borrow_mut(); diff --git a/src/backend/x86.rs b/src/backend/x86.rs index 5acf817..0e97779 100644 --- a/src/backend/x86.rs +++ b/src/backend/x86.rs @@ -31,7 +31,14 @@ impl Drop for Aes128Ctr64 { } impl Aes128Ctr64 { - #[cfg(feature = "tls")] + #[cfg(all( + feature = "tls", + not(any( + feature = "tls_aes128_ctr128", + feature = "tls_aes256_ctr64", + feature = "tls_aes256_ctr128" + )) + ))] pub(crate) const fn zeroed() -> Self { Self { counter: Cell::new(unsafe { core::mem::zeroed() }), @@ -129,6 +136,14 @@ impl Drop for Aes128Ctr128 { } impl Aes128Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes128_ctr128"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(0), + round_keys: Cell::new(unsafe { core::mem::zeroed() }), + } + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); self.counter.set(self.counter.get() + (1 << 64)); @@ -222,6 +237,14 @@ impl Drop for Aes256Ctr64 { } impl Aes256Ctr64 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr64"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(unsafe { core::mem::zeroed() }), + round_keys: Cell::new(unsafe { core::mem::zeroed() }), + } + } + #[cfg_attr(not(target_feature = "sse2"), target_feature(enable = "sse2"))] #[cfg_attr(not(target_feature = "aes"), target_feature(enable = "aes"))] pub(crate) unsafe fn from_seed_impl(key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) -> Self { @@ -316,6 +339,14 @@ impl Drop for Aes256Ctr128 { } impl Aes256Ctr128 { + #[cfg(all(feature = "tls", feature = "tls_aes256_ctr128"))] + pub(crate) const fn zeroed() -> Self { + Self { + counter: Cell::new(0), + round_keys: Cell::new(unsafe { core::mem::zeroed() }), + } + } + pub(crate) fn jump_impl(&self) -> Self { let clone = self.clone(); self.counter.set(self.counter.get() + (1 << 64)); diff --git a/src/tls.rs b/src/tls.rs index abb8354..3e2ce28 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -7,9 +7,40 @@ //! [`rand_seed_from_entropy()`] or [`rand_seed()`] function. use core::ops::RangeBounds; -use crate::seeds::Aes128Ctr64Seed; use crate::Random; +#[cfg(not(any( + feature = "tls_aes128_ctr128", + feature = "tls_aes256_ctr64", + feature = "tls_aes256_ctr128" +)))] +use crate::Aes128Ctr64 as Prng; + +#[cfg(feature = "tls_aes128_ctr128")] +use crate::Aes128Ctr128 as Prng; + +#[cfg(feature = "tls_aes256_ctr64")] +use crate::Aes256Ctr64 as Prng; + +#[cfg(feature = "tls_aes256_ctr128")] +use crate::Aes256Ctr128 as Prng; + +#[cfg(not(any( + feature = "tls_aes128_ctr128", + feature = "tls_aes256_ctr64", + feature = "tls_aes256_ctr128" +)))] +pub use crate::seeds::Aes128Ctr64Seed as Seed; + +#[cfg(feature = "tls_aes128_ctr128")] +pub use crate::seeds::Aes128Ctr128Seed as Seed; + +#[cfg(feature = "tls_aes256_ctr64")] +pub use crate::seeds::Aes256Ctr64Seed as Seed; + +#[cfg(feature = "tls_aes256_ctr128")] +pub use crate::seeds::Aes256Ctr128Seed as Seed; + #[cfg(any( not(any( not(any( @@ -30,7 +61,7 @@ use crate::Random; feature = "force_software" ))] thread_local! { - pub(super) static RNG: crate::Aes128Ctr64 = const { crate::Aes128Ctr64::zeroed() }; + pub(super) static RNG: Prng = const { Prng::zeroed() }; } #[cfg(all( @@ -53,7 +84,7 @@ thread_local! { not(feature = "force_software") ))] thread_local! { - pub(super) static RNG: core::cell::LazyCell = core::cell::LazyCell::new(crate::Aes128Ctr64::zeroed); + pub(super) static RNG: core::cell::LazyCell = core::cell::LazyCell::new(Prng::zeroed); } /// Seeds the thread local instance using the OS entropy source. @@ -66,7 +97,7 @@ pub fn rand_seed_from_entropy() { } /// Seeds the thread local instance with the given seed. -pub fn rand_seed(seed: Aes128Ctr64Seed) { +pub fn rand_seed(seed: Seed) { RNG.with(|rng| rng.seed(seed)) } diff --git a/tests/tls.rs b/tests/tls.rs index caa5fe4..2a3b087 100644 --- a/tests/tls.rs +++ b/tests/tls.rs @@ -1,13 +1,13 @@ #[cfg(feature = "tls")] mod test { - use rand_aes::seeds::*; + use rand_aes::tls::*; macro_rules! test_range { ($name:ident, $method:ident, $range:expr) => { #[test] fn $name() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); for _ in 0..1000 { let value = $method($range.clone()); assert!( @@ -36,7 +36,7 @@ mod test { ($name:ident, $method:ident, $max:expr) => { #[test] fn $name() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); for _ in 0..1000 { let value = $method($max); assert!( @@ -60,7 +60,7 @@ mod test { ($name:ident, $method:ident) => { #[test] fn $name() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); for _ in 0..1000 { let _value = $method(); } @@ -81,7 +81,7 @@ mod test { #[test] fn test_prng_bool() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); let mut true_count = 0; let mut false_count = 0; for _ in 0..1000 { @@ -98,7 +98,7 @@ mod test { #[test] fn test_prng_f32() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); for _ in 0..1000 { let value = rand_f32(); assert!( @@ -110,7 +110,7 @@ mod test { #[test] fn test_prng_f64() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); for _ in 0..1000 { let value = rand_f64(); assert!( @@ -122,7 +122,7 @@ mod test { #[test] fn test_prng_fill_bytes() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); let mut bytes = [0u8; 16]; rand_fill_bytes(&mut bytes); assert!( @@ -133,7 +133,7 @@ mod test { #[test] fn test_prng_shuffle() { - rand_seed(Aes128Ctr64Seed::default()); + rand_seed(Seed::default()); let mut array = [0usize; 256]; for (i, x) in array.as_mut_slice().iter_mut().enumerate() { *x = i;