From 724aa468717d78e55d0c2325463fe546e9e50e96 Mon Sep 17 00:00:00 2001 From: Nils Hasenbanck Date: Tue, 6 Aug 2024 11:59:50 +0200 Subject: [PATCH] Support RISCV64 vector extension. --- .cargo/config.toml | 2 +- .github/workflows/ci.yml | 14 +- Cargo.toml | 2 + README.md | 27 +- scripts/run_verification.sh | 4 +- src/fallback/mod.rs | 35 +- src/fallback/runtime.rs | 8 +- src/hardware/mod.rs | 4 +- src/hardware/riscv64.rs | 848 ++++++++++++++------------- src/implementation.rs | 4 +- src/lib.rs | 32 +- src/tls.rs | 4 +- src/verification.rs | 16 +- verification/Cargo.toml | 7 +- verification/src/bin/verification.rs | 4 +- 15 files changed, 504 insertions(+), 507 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index b4985b0..3fb8d8d 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -2,7 +2,7 @@ rustflags = ["-C", "target-feature=+neon,+aes"] [target.'cfg(target_arch="riscv64")'] -rustflags = ["-C", "target-feature=+zkne"] +rustflags = ["-C", "target-feature=+v,+zvkn"] [target.'cfg(target_arch="x86_64")'] rustflags = ["-C", "target-feature=+sse2,+aes"] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index be685d9..89fd2f8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,17 +19,14 @@ jobs: - name: Linux x86_64 os: ubuntu-24.04 target: x86_64-unknown-linux-gnu - channel: stable - name: Linux riscv64gc os: ubuntu-24.04 target: riscv64gc-unknown-linux-gnu - channel: nightly - name: MacOS aarch64 os: macos-latest target: aarch64-apple-darwin - channel: stable name: Clippy ${{ matrix.name }} runs-on: ${{ matrix.os }} @@ -40,8 +37,8 @@ jobs: - name: Install toolchain run: | - rustup toolchain install ${{ matrix.channel }} --no-self-update --profile=minimal --component clippy --target ${{ matrix.target }} - rustup override set ${{ matrix.channel }} + rustup toolchain install stable --no-self-update --profile=minimal --component clippy --target ${{ matrix.target }} + rustup override set stable cargo -V - name: Caching @@ -139,16 +136,13 @@ jobs: - name: aarch64 arch: aarch64 target: aarch64-unknown-linux-gnu - channel: stable cpu: a64fx - name: riscv64 arch: riscv64 target: riscv64gc-unknown-linux-gnu - channel: nightly - name: x86_64 arch: x86-64 target: x86_64-unknown-linux-gnu - channel: stable name: Validate ${{ matrix.name }} runs-on: ubuntu-24.04 @@ -165,8 +159,8 @@ jobs: - name: Install toolchain run: | - rustup toolchain install ${{ matrix.channel }} --no-self-update --profile=minimal --target ${{ matrix.target }} - rustup override set ${{ matrix.channel }} + rustup toolchain install stable --no-self-update --profile=minimal --target ${{ matrix.target }} + rustup override set stable cargo -V - name: Caching diff --git a/Cargo.toml b/Cargo.toml index ac3d316..b27a2c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,8 @@ default = ["std", "tls", "getrandom", "rand_core"] std = [] # Activates the thread local functionality. tls = ["std"] +# Enables support for experimental RISC-V vector cryptography extension. Please read the README.md. +experimental_riscv = [] ### The following features are only used internally and are unstable ### # Forces the compiler to always use the fallback (never using the hardware AES directly). diff --git a/README.md b/README.md index 31cfbe0..61abde0 100644 --- a/README.md +++ b/README.md @@ -66,10 +66,29 @@ We provide a software implementation of AES in case there is no hardware acceler accelerated versions for the following architectures: - aarch64: Support since Cortex-A53 (2012). 
-- riscv64: Must support the scalar based cryptography extension (zk).
+- riscv64: Experimental support using the vector crypto extension.
 - x86_64: Support since Intel's Westmere (2010) and AMD's Bulldozer (2011).
 
-riscv64 needs nightly Rust, since the AES intrinsics are not marked as stable yet.
+## Experimental RISC-V support
+
+There are two AES extensions for RISC-V: the scalar cryptography extension (zkn) and the vector crypto extension
+(zvkn). Currently there is no hardware available that supports either of them, and it is not yet clear which one
+the platforms will favor. Since this crate mainly targets application class processors (as opposed to embedded
+ones), we think that a vector crypto implementation is the safest bet. There are no intrinsics for the vector
+crypto extension yet, so we provide a handwritten assembly implementation. Since there is currently no way to
+detect vector crypto support at runtime (as of August 2024), you must select the required target features and an
+experimental crate feature at compile time; the resulting executable will only run on hardware with this support.
+
+Activate the target features for the vector extension and the vector crypto extension. This can be done, for
+example, inside the `.cargo/config.toml`:
+
+```toml
+[target.'cfg(target_arch="riscv64")']
+rustflags = ["-C", "target-feature=+v,+zvkn"]
+```
+
+You also need to enable the `experimental_riscv` crate feature. This feature is experimental and will most likely
+become obsolete in the future (once intrinsics and runtime detection are available).
 
 ## Optimal Performance
 
@@ -83,7 +102,7 @@ much better performance. The runtime detection is not supported in `no_std`.
 Use the following target features for optimal performance:
 
 - aarch64: "aes" (using the cryptographic extension)
-- riscv64: "zkne" (using the scalar based cryptography extension)
+- riscv64: "v" and "zvkn" (using the vector and vector crypto extensions)
 - x86_64: "aes" (using AES-NI)
 
 Example in `.cargo/config.toml`:
 
@@ -93,7 +112,7 @@
 rustflags = ["-C", "target-feature=+aes"]
 
 [target.'cfg(target_arch="riscv64")']
-rustflags = ["-C", "target-feature=+zkne"]
+rustflags = ["-C", "target-feature=+v,+zvkn"]
 
 [target.'cfg(target_arch="x86_64")']
 rustflags = ["-C", "target-feature=+aes"]
diff --git a/scripts/run_verification.sh b/scripts/run_verification.sh
index cf79e92..101791d 100755
--- a/scripts/run_verification.sh
+++ b/scripts/run_verification.sh
@@ -25,8 +25,8 @@ case $ARCH in
         qemu-aarch64 -cpu cortex-a53 -L /usr/aarch64-linux-gnu ../target/aarch64-unknown-linux-gnu/release/verification
         ;;
     "riscv64")
-        CARGO_TARGET_RISCV64GC_UNKNOWN_LINUX_GNU_LINKER=riscv64-linux-gnu-gcc cargo +nightly build --release --target=riscv64gc-unknown-linux-gnu
-        qemu-riscv64 -cpu rv64,zk=true -L /usr/riscv64-linux-gnu ../target/riscv64gc-unknown-linux-gnu/release/verification
+        CARGO_TARGET_RISCV64GC_UNKNOWN_LINUX_GNU_LINKER=riscv64-linux-gnu-gcc cargo build --release --target=riscv64gc-unknown-linux-gnu --no-default-features --features=experimental_riscv
+        qemu-riscv64 -cpu rv64,v=true,vlen=128,zvkn=true -L /usr/riscv64-linux-gnu ../target/riscv64gc-unknown-linux-gnu/release/verification
         ;;
     "x86_64")
         CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-linux-gnu-gcc cargo build --release --target=x86_64-unknown-linux-gnu
diff --git a/src/fallback/mod.rs b/src/fallback/mod.rs
index ff5f230..90f0b60 100644
--- a/src/fallback/mod.rs
+++ b/src/fallback/mod.rs
@@ -5,14 +5,7 @@
 //! 
- Fixed: Always uses the software AES implementation. #[cfg(all( any( - not(all( - feature = "std", - any( - target_arch = "aarch64", - target_arch = "riscv64", - target_arch = "x86_64", - ) - )), + not(all(feature = "std", any(target_arch = "aarch64", target_arch = "x86_64",))), feature = "force_no_runtime_detection" ), not(feature = "verification") @@ -21,14 +14,7 @@ mod fixed; #[cfg(all( not(any( - not(all( - feature = "std", - any( - target_arch = "aarch64", - target_arch = "riscv64", - target_arch = "x86_64", - ) - )), + not(all(feature = "std", any(target_arch = "aarch64", target_arch = "x86_64",))), feature = "force_no_runtime_detection" )), not(feature = "verification") @@ -37,13 +23,9 @@ mod runtime; pub(crate) mod software; -#[cfg(all( +#[cfg(all( feature = "std", - any( - target_arch = "aarch64", - target_arch = "riscv64", - target_arch = "x86_64", - ), + any(target_arch = "aarch64", target_arch = "x86_64",), not(feature = "force_no_runtime_detection"), not(feature = "verification") ))] @@ -51,14 +33,7 @@ pub use runtime::{Aes128Ctr128, Aes128Ctr64, Aes256Ctr128, Aes256Ctr64}; #[cfg(all( any( - not(all( - feature = "std", - any( - target_arch = "aarch64", - target_arch = "riscv64", - target_arch = "x86_64", - ) - )), + not(all(feature = "std", any(target_arch = "aarch64", target_arch = "x86_64",))), feature = "force_no_runtime_detection" ), not(feature = "verification") diff --git a/src/fallback/runtime.rs b/src/fallback/runtime.rs index 694b8b5..2c1deb9 100644 --- a/src/fallback/runtime.rs +++ b/src/fallback/runtime.rs @@ -23,10 +23,6 @@ pub(crate) fn has_hardware_acceleration() -> bool { { return true; } - #[cfg(target_arch = "riscv64")] - if std::arch::is_riscv_feature_detected!("zkne") { - return true; - } false } @@ -265,9 +261,7 @@ impl Aes256Ctr64 { // Safety: We checked that the hardware acceleration is available. unsafe { this.next_impl() } } - Aes256Ctr64Inner::Software(this) => { - this.borrow_mut().next_impl() - }, + Aes256Ctr64Inner::Software(this) => this.borrow_mut().next_impl(), } } } diff --git a/src/hardware/mod.rs b/src/hardware/mod.rs index 360bdb3..8b56e0a 100644 --- a/src/hardware/mod.rs +++ b/src/hardware/mod.rs @@ -21,9 +21,7 @@ pub use x86_64::{Aes128Ctr128, Aes128Ctr64, Aes256Ctr128, Aes256Ctr64}; target_feature = "sse2", target_feature = "aes", ), - all( - target_arch = "riscv64", target_feature = "zkne" - ), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", diff --git a/src/hardware/riscv64.rs b/src/hardware/riscv64.rs index 593d803..79b6975 100644 --- a/src/hardware/riscv64.rs +++ b/src/hardware/riscv64.rs @@ -1,6 +1,6 @@ -use core::{arch::riscv64::*, cell::Cell}; +use core::{arch::asm, cell::Cell}; -use crate::constants::{AES128_KEY_COUNT, AES256_KEY_COUNT}; +use crate::constants::{AES128_KEY_COUNT, AES128_KEY_SIZE, AES256_KEY_COUNT, AES256_KEY_SIZE}; /// A random number generator based on the AES-128 block cipher that runs in CTR mode and has a /// period of 64-bit. 
@@ -9,13 +9,13 @@ use crate::constants::{AES128_KEY_COUNT, AES256_KEY_COUNT}; #[derive(Clone)] pub struct Aes128Ctr64 { counter: Cell<[u64; 2]>, - round_keys: Cell<[[u64; 2]; AES128_KEY_COUNT]>, + round_keys: Cell<[u128; AES128_KEY_COUNT]>, } impl Drop for Aes128Ctr64 { fn drop(&mut self) { self.counter.set([0, 0]); - self.round_keys.set([[0; 2]; AES128_KEY_COUNT]); + self.round_keys.set([0; AES128_KEY_COUNT]); core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst); } } @@ -25,20 +25,16 @@ impl Aes128Ctr64 { pub(crate) const fn zeroed() -> Self { Self { counter: Cell::new([0; 2]), - round_keys: Cell::new([[0; 2]; AES128_KEY_COUNT]), + round_keys: Cell::new([0; AES128_KEY_COUNT]), } } - #[target_feature(enable = "zkne")] pub(crate) unsafe fn from_seed_impl(key: [u8; 16], nonce: [u8; 8], counter: [u8; 8]) -> Self { - let mut key_lo = [0u8; 8]; - let mut key_hi = [0u8; 8]; - - key_lo.copy_from_slice(&key[0..8]); - key_hi.copy_from_slice(&key[8..16]); + let mut key_0 = [0u8; 16]; + key_0.copy_from_slice(&key[0..16]); let counter = [u64::from_le_bytes(counter), u64::from_le_bytes(nonce)]; - let key = [u64::from_le_bytes(key_lo), u64::from_le_bytes(key_hi)]; + let key = u128::from_le_bytes(key_0); let round_keys = aes128_key_expansion(key); @@ -48,16 +44,12 @@ impl Aes128Ctr64 { } } - #[target_feature(enable = "zkne")] pub(crate) unsafe fn seed_impl(&self, key: [u8; 16], nonce: [u8; 8], counter: [u8; 8]) { - let mut key_lo = [0u8; 8]; - let mut key_hi = [0u8; 8]; - - key_lo.copy_from_slice(&key[0..8]); - key_hi.copy_from_slice(&key[8..16]); + let mut key_0 = [0u8; 16]; + key_0.copy_from_slice(&key[0..16]); let counter = [u64::from_le_bytes(counter), u64::from_le_bytes(nonce)]; - let key = [u64::from_le_bytes(key_lo), u64::from_le_bytes(key_hi)]; + let key = u128::from_le_bytes(key_0); let round_keys = aes128_key_expansion(key); @@ -73,8 +65,7 @@ impl Aes128Ctr64 { self.counter.get()[0] } - #[cfg_attr(target_feature = "zkne", inline(always))] - #[cfg_attr(not(target_feature = "zkne"), target_feature(enable = "zkne"))] + #[inline(always)] pub(crate) unsafe fn next_impl(&self) -> u128 { // Increment the lower 64 bits. let counter = self.counter.get(); @@ -83,49 +74,64 @@ impl Aes128Ctr64 { self.counter.set(new_counter); let round_keys = self.round_keys.get(); - - // We apply the AES encryption on the counter. 
-        let mut state = [counter[0] ^ round_keys[0][0], counter[1] ^ round_keys[0][1]];
-
-        let mut temp0 = aes64esm(state[0], state[1]);
-        let mut temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[1][0], temp1 ^ round_keys[1][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[2][0], temp1 ^ round_keys[2][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[3][0], temp1 ^ round_keys[3][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[4][0], temp1 ^ round_keys[4][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[5][0], temp1 ^ round_keys[5][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[6][0], temp1 ^ round_keys[6][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[7][0], temp1 ^ round_keys[7][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[8][0], temp1 ^ round_keys[8][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[9][0], temp1 ^ round_keys[9][1]];
-
-        temp0 = aes64es(state[0], state[1]);
-        temp1 = aes64es(state[1], state[0]);
-        state = [temp0 ^ round_keys[10][0], temp1 ^ round_keys[10][1]];
+        let mut round_keys_ptr = (&round_keys).as_ptr();
+
+        // Initialize the state with the counter.
+        let mut state = counter;
+        let state_ptr = (&mut state).as_mut_ptr();
+
+        asm!(
+            "vsetivli x0, 4, e32, m1, ta, ma",
+            "vle32.v v0, (t0)", // Load counter into a register
+            "vle32.v v1, (t1)", // Copy all round keys into the vector registers
+            "addi t1, t1, 16",
+            "vle32.v v2, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v3, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v4, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v5, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v6, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v7, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v8, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v9, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v10, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v11, (t1)",
+            "vaesz.vs v0, v1", // Whiten the counter
+            "vaesem.vs v0, v2", // Apply 10 rounds of AES
+            "vaesem.vs v0, v3",
+            "vaesem.vs v0, v4",
+            "vaesem.vs v0, v5",
+            "vaesem.vs v0, v6",
+            "vaesem.vs v0, v7",
+            "vaesem.vs v0, v8",
+            "vaesem.vs v0, v9",
+            "vaesem.vs v0, v10",
+            "vaesef.vs v0, v11",
+            "vse32.v v0, (t0)", // Store the state
+            in("t0") state_ptr,
+            inlateout("t1") round_keys_ptr,
+            out("v0") _,
+            out("v1") _,
+            out("v2") _,
+            out("v3") _,
+            out("v4") _,
+            out("v5") _,
+            out("v6") _,
+            out("v7") _,
+            out("v8") _,
+            out("v9") _,
+            out("v10") _,
+            out("v11") _,
+            options(nostack),
+        );
 
         // Return the encrypted counter as u128.
        u128::from(state[0]) | (u128::from(state[1]) << 64)
    }
}

@@ -139,13 +145,13 @@
 #[derive(Clone)]
 pub struct Aes128Ctr128 {
     counter: Cell<u128>,
-    round_keys: Cell<[[u64; 2]; AES128_KEY_COUNT]>,
+    round_keys: Cell<[u128; AES128_KEY_COUNT]>,
 }
 
 impl Drop for Aes128Ctr128 {
     fn drop(&mut self) {
         self.counter.set(0);
-        self.round_keys.set([[0; 2]; AES128_KEY_COUNT]);
+        self.round_keys.set([0; AES128_KEY_COUNT]);
         core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
     }
 }
@@ -163,16 +169,12 @@
         clone
     }
 
-    #[target_feature(enable = "zkne")]
     pub(crate) unsafe fn from_seed_impl(key: [u8; 16], counter: [u8; 16]) -> Self {
-        let mut key_lo = [0u8; 8];
-        let mut key_hi = [0u8; 8];
-
-        key_lo.copy_from_slice(&key[0..8]);
-        key_hi.copy_from_slice(&key[8..16]);
+        let mut key_0 = [0u8; 16];
+        key_0.copy_from_slice(&key[0..16]);
 
         let counter = u128::from_le_bytes(counter);
-        let key = [u64::from_le_bytes(key_lo), u64::from_le_bytes(key_hi)];
+        let key = u128::from_le_bytes(key_0);
 
         let round_keys = aes128_key_expansion(key);
 
@@ -182,16 +184,12 @@
         }
     }
 
-    #[target_feature(enable = "zkne")]
     pub(crate) unsafe fn seed_impl(&self, key: [u8; 16], counter: [u8; 16]) {
-        let mut key_lo = [0u8; 8];
-        let mut key_hi = [0u8; 8];
-
-        key_lo.copy_from_slice(&key[0..8]);
-        key_hi.copy_from_slice(&key[8..16]);
+        let mut key_0 = [0u8; 16];
+        key_0.copy_from_slice(&key[0..16]);
 
         let counter = u128::from_le_bytes(counter);
-        let key = [u64::from_le_bytes(key_lo), u64::from_le_bytes(key_hi)];
+        let key = u128::from_le_bytes(key_0);
 
         let round_keys = aes128_key_expansion(key);
 
@@ -207,65 +205,74 @@
         self.counter.get()
     }
 
-    #[cfg_attr(target_feature = "zkne", inline(always))]
-    #[cfg_attr(not(target_feature = "zkne"), target_feature(enable = "zkne"))]
+    #[inline(always)]
     pub(crate) unsafe fn next_impl(&self) -> u128 {
         // Increment the counter.
         let counter = self.counter.get();
         self.counter.set(counter.wrapping_add(1));
 
         let round_keys = self.round_keys.get();
-        let counter_low = counter as u64;
-        let counter_high = (counter >> 64) as u64;
-
-        // We apply the AES encryption on the counter.
-        let mut state = [
-            counter_low ^ round_keys[0][0],
-            counter_high ^ round_keys[0][1],
-        ];
-
-        let mut temp0 = aes64esm(state[0], state[1]);
-        let mut temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[1][0], temp1 ^ round_keys[1][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[2][0], temp1 ^ round_keys[2][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[3][0], temp1 ^ round_keys[3][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[4][0], temp1 ^ round_keys[4][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[5][0], temp1 ^ round_keys[5][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[6][0], temp1 ^ round_keys[6][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[7][0], temp1 ^ round_keys[7][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[8][0], temp1 ^ round_keys[8][1]];
-
-        temp0 = aes64esm(state[0], state[1]);
-        temp1 = aes64esm(state[1], state[0]);
-        state = [temp0 ^ round_keys[9][0], temp1 ^ round_keys[9][1]];
-
-        temp0 = aes64es(state[0], state[1]);
-        temp1 = aes64es(state[1], state[0]);
-        state = [temp0 ^ round_keys[10][0], temp1 ^ round_keys[10][1]];
+        let mut round_keys_ptr = (&round_keys).as_ptr();
+
+        // Initialize the state with the counter.
+        let mut state = counter;
+        let state_ptr = (&mut state) as *mut u128;
+
+        asm!(
+            "vsetivli x0, 4, e32, m1, ta, ma",
+            "vle32.v v0, (t0)", // Load counter into a register
+            "vle32.v v1, (t1)", // Copy all round keys into the vector registers
+            "addi t1, t1, 16",
+            "vle32.v v2, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v3, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v4, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v5, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v6, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v7, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v8, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v9, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v10, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v11, (t1)",
+            "vaesz.vs v0, v1", // Whiten the counter
+            "vaesem.vs v0, v2", // Apply 10 rounds of AES
+            "vaesem.vs v0, v3",
+            "vaesem.vs v0, v4",
+            "vaesem.vs v0, v5",
+            "vaesem.vs v0, v6",
+            "vaesem.vs v0, v7",
+            "vaesem.vs v0, v8",
+            "vaesem.vs v0, v9",
+            "vaesem.vs v0, v10",
+            "vaesef.vs v0, v11",
+            "vse32.v v0, (t0)", // Store the state
+            in("t0") state_ptr,
+            inlateout("t1") round_keys_ptr,
+            out("v0") _,
+            out("v1") _,
+            out("v2") _,
+            out("v3") _,
+            out("v4") _,
+            out("v5") _,
+            out("v6") _,
+            out("v7") _,
+            out("v8") _,
+            out("v9") _,
+            out("v10") _,
+            out("v11") _,
+            options(nostack),
+        );
 
         // Return the encrypted counter as u128.
- u128::from(state[0]) | (u128::from(state[1]) << 64) + state } } @@ -276,37 +283,27 @@ impl Aes128Ctr128 { #[derive(Clone)] pub struct Aes256Ctr64 { counter: Cell<[u64; 2]>, - round_keys: Cell<[[u64; 2]; AES256_KEY_COUNT]>, + round_keys: Cell<[u128; AES256_KEY_COUNT]>, } impl Drop for Aes256Ctr64 { fn drop(&mut self) { self.counter.set([0, 0]); - self.round_keys.set([[0; 2]; AES256_KEY_COUNT]); + self.round_keys.set([0; AES256_KEY_COUNT]); core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst); } } impl Aes256Ctr64 { - #[target_feature(enable = "zkne")] pub(crate) unsafe fn from_seed_impl(key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) -> Self { - let mut key_0 = [0u8; 8]; - let mut key_1 = [0u8; 8]; - let mut key_2 = [0u8; 8]; - let mut key_3 = [0u8; 8]; + let mut key_0 = [0u8; 16]; + let mut key_1 = [0u8; 16]; - key_0.copy_from_slice(&key[0..8]); - key_1.copy_from_slice(&key[8..16]); - key_2.copy_from_slice(&key[16..24]); - key_3.copy_from_slice(&key[24..32]); + key_0.copy_from_slice(&key[0..16]); + key_1.copy_from_slice(&key[16..32]); let counter = [u64::from_le_bytes(counter), u64::from_le_bytes(nonce)]; - let key = [ - u64::from_le_bytes(key_0), - u64::from_le_bytes(key_1), - u64::from_le_bytes(key_2), - u64::from_le_bytes(key_3), - ]; + let key = [u128::from_le_bytes(key_0), u128::from_le_bytes(key_1)]; let round_keys = aes256_key_expansion(key); @@ -316,25 +313,15 @@ impl Aes256Ctr64 { } } - #[target_feature(enable = "zkne")] pub(crate) unsafe fn seed_impl(&self, key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) { - let mut key_0 = [0u8; 8]; - let mut key_1 = [0u8; 8]; - let mut key_2 = [0u8; 8]; - let mut key_3 = [0u8; 8]; + let mut key_0 = [0u8; 16]; + let mut key_1 = [0u8; 16]; - key_0.copy_from_slice(&key[0..8]); - key_1.copy_from_slice(&key[8..16]); - key_2.copy_from_slice(&key[16..24]); - key_3.copy_from_slice(&key[24..32]); + key_0.copy_from_slice(&key[0..16]); + key_1.copy_from_slice(&key[16..32]); let counter = [u64::from_le_bytes(counter), u64::from_le_bytes(nonce)]; - let key = [ - u64::from_le_bytes(key_0), - u64::from_le_bytes(key_1), - u64::from_le_bytes(key_2), - u64::from_le_bytes(key_3), - ]; + let key = [u128::from_le_bytes(key_0), u128::from_le_bytes(key_1)]; let round_keys = aes256_key_expansion(key); @@ -350,8 +337,7 @@ impl Aes256Ctr64 { self.counter.get()[0] } - #[cfg_attr(target_feature = "zkne", inline(always))] - #[cfg_attr(not(target_feature = "zkne"), target_feature(enable = "zkne"))] + #[inline(always)] pub(crate) unsafe fn next_impl(&self) -> u128 { // Increment the lower 64 bits. let counter = self.counter.get(); @@ -360,65 +346,80 @@ impl Aes256Ctr64 { self.counter.set(new_counter); let round_keys = self.round_keys.get(); - - // We apply the AES encryption on the counter. 
- let mut state = [counter[0] ^ round_keys[0][0], counter[1] ^ round_keys[0][1]]; - - let mut temp0 = aes64esm(state[0], state[1]); - let mut temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[1][0], temp1 ^ round_keys[1][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[2][0], temp1 ^ round_keys[2][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[3][0], temp1 ^ round_keys[3][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[4][0], temp1 ^ round_keys[4][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[5][0], temp1 ^ round_keys[5][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[6][0], temp1 ^ round_keys[6][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[7][0], temp1 ^ round_keys[7][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[8][0], temp1 ^ round_keys[8][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[9][0], temp1 ^ round_keys[9][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[10][0], temp1 ^ round_keys[10][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[11][0], temp1 ^ round_keys[11][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[12][0], temp1 ^ round_keys[12][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[13][0], temp1 ^ round_keys[13][1]]; - - temp0 = aes64es(state[0], state[1]); - temp1 = aes64es(state[1], state[0]); - state = [temp0 ^ round_keys[14][0], temp1 ^ round_keys[14][1]]; + let mut round_keys_ptr = (&round_keys).as_ptr(); + + // Initialize the state with the counter. 
+        let mut state = counter;
+        let state_ptr = (&mut state).as_mut_ptr();
+
+        asm!(
+            "vsetivli x0, 4, e32, m1, ta, ma",
+            "vle32.v v0, (t0)", // Load counter into a register
+            "vle32.v v1, (t1)", // Copy all round keys into the vector registers
+            "addi t1, t1, 16",
+            "vle32.v v2, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v3, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v4, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v5, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v6, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v7, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v8, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v9, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v10, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v11, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v12, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v13, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v14, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v15, (t1)",
+            "vaesz.vs v0, v1", // Whiten the counter
+            "vaesem.vs v0, v2", // Apply 14 rounds of AES
+            "vaesem.vs v0, v3",
+            "vaesem.vs v0, v4",
+            "vaesem.vs v0, v5",
+            "vaesem.vs v0, v6",
+            "vaesem.vs v0, v7",
+            "vaesem.vs v0, v8",
+            "vaesem.vs v0, v9",
+            "vaesem.vs v0, v10",
+            "vaesem.vs v0, v11",
+            "vaesem.vs v0, v12",
+            "vaesem.vs v0, v13",
+            "vaesem.vs v0, v14",
+            "vaesef.vs v0, v15",
+            "vse32.v v0, (t0)", // Store the state
+            in("t0") state_ptr,
+            inlateout("t1") round_keys_ptr,
+            out("v0") _,
+            out("v1") _,
+            out("v2") _,
+            out("v3") _,
+            out("v4") _,
+            out("v5") _,
+            out("v6") _,
+            out("v7") _,
+            out("v8") _,
+            out("v9") _,
+            out("v10") _,
+            out("v11") _,
+            out("v12") _,
+            out("v13") _,
+            out("v14") _,
+            out("v15") _,
+            options(nostack),
+        );
 
         // Return the encrypted counter as u128.
         u128::from(state[0]) | (u128::from(state[1]) << 64)
     }
 }
 
@@ -432,13 +433,13 @@
 #[derive(Clone)]
 pub struct Aes256Ctr128 {
     counter: Cell<u128>,
-    round_keys: Cell<[[u64; 2]; AES256_KEY_COUNT]>,
+    round_keys: Cell<[u128; AES256_KEY_COUNT]>,
 }
 
 impl Drop for Aes256Ctr128 {
     fn drop(&mut self) {
         self.counter.set(0);
-        self.round_keys.set([[0; 2]; AES256_KEY_COUNT]);
+        self.round_keys.set([0; AES256_KEY_COUNT]);
         core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
     }
 }
@@ -456,25 +457,15 @@
         clone
     }
 
-    #[target_feature(enable = "zkne")]
     pub(crate) unsafe fn from_seed_impl(key: [u8; 32], counter: [u8; 16]) -> Self {
-        let mut key_0 = [0u8; 8];
-        let mut key_1 = [0u8; 8];
-        let mut key_2 = [0u8; 8];
-        let mut key_3 = [0u8; 8];
+        let mut key_0 = [0u8; 16];
+        let mut key_1 = [0u8; 16];
 
-        key_0.copy_from_slice(&key[0..8]);
-        key_1.copy_from_slice(&key[8..16]);
-        key_2.copy_from_slice(&key[16..24]);
-        key_3.copy_from_slice(&key[24..32]);
+        key_0.copy_from_slice(&key[0..16]);
+        key_1.copy_from_slice(&key[16..32]);
 
         let counter = u128::from_le_bytes(counter);
-        let key = [
-            u64::from_le_bytes(key_0),
-            u64::from_le_bytes(key_1),
-            u64::from_le_bytes(key_2),
-            u64::from_le_bytes(key_3),
-        ];
+        let key = [u128::from_le_bytes(key_0), u128::from_le_bytes(key_1)];
 
         let round_keys = aes256_key_expansion(key);
 
@@ -488,25 +479,15 @@
         self.counter.get()
     }
 
-    #[target_feature(enable = "zkne")]
     pub(crate) unsafe fn seed_impl(&self, key: [u8; 32], counter: [u8; 16]) {
-        let mut key_0 = [0u8; 8];
-        let mut key_1 = [0u8; 8];
-        let mut key_2 = [0u8; 8];
-        let mut key_3 = [0u8; 8];
+        let mut key_0 = [0u8; 16];
+        let mut key_1 = [0u8; 16];
 
-        key_0.copy_from_slice(&key[0..8]);
-        key_1.copy_from_slice(&key[8..16]);
-        key_2.copy_from_slice(&key[16..24]);
-        key_3.copy_from_slice(&key[24..32]);
+        key_0.copy_from_slice(&key[0..16]);
+
key_1.copy_from_slice(&key[16..32]); let counter = u128::from_le_bytes(counter); - let key = [ - u64::from_le_bytes(key_0), - u64::from_le_bytes(key_1), - u64::from_le_bytes(key_2), - u64::from_le_bytes(key_3), - ]; + let key = [u128::from_le_bytes(key_0), u128::from_le_bytes(key_1)]; let round_keys = aes256_key_expansion(key); @@ -518,164 +499,204 @@ impl Aes256Ctr128 { true } - #[cfg_attr(target_feature = "zkne", inline(always))] - #[cfg_attr(not(target_feature = "zkne"), target_feature(enable = "zkne"))] + #[inline(always)] pub(crate) unsafe fn next_impl(&self) -> u128 { // Increment the counter. let counter = self.counter.get(); self.counter.set(counter.wrapping_add(1)); let round_keys = self.round_keys.get(); - let counter_low = counter as u64; - let counter_high = (counter >> 64) as u64; - - // We apply the AES encryption on the counter. - let mut state = [ - counter_low ^ round_keys[0][0], - counter_high ^ round_keys[0][1], - ]; - - let mut temp0 = aes64esm(state[0], state[1]); - let mut temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[1][0], temp1 ^ round_keys[1][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[2][0], temp1 ^ round_keys[2][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[3][0], temp1 ^ round_keys[3][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[4][0], temp1 ^ round_keys[4][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[5][0], temp1 ^ round_keys[5][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[6][0], temp1 ^ round_keys[6][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[7][0], temp1 ^ round_keys[7][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[8][0], temp1 ^ round_keys[8][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[9][0], temp1 ^ round_keys[9][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[10][0], temp1 ^ round_keys[10][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[11][0], temp1 ^ round_keys[11][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[12][0], temp1 ^ round_keys[12][1]]; - - temp0 = aes64esm(state[0], state[1]); - temp1 = aes64esm(state[1], state[0]); - state = [temp0 ^ round_keys[13][0], temp1 ^ round_keys[13][1]]; - - temp0 = aes64es(state[0], state[1]); - temp1 = aes64es(state[1], state[0]); - state = [temp0 ^ round_keys[14][0], temp1 ^ round_keys[14][1]]; + let mut round_keys_ptr = (&round_keys).as_ptr(); + + // Initialize the state with the counter. 
+        let mut state = counter;
+        let state_ptr = (&mut state) as *mut u128;
+
+        asm!(
+            "vsetivli x0, 4, e32, m1, ta, ma",
+            "vle32.v v0, (t0)", // Load counter into a register
+            "vle32.v v1, (t1)", // Copy all round keys into the vector registers
+            "addi t1, t1, 16",
+            "vle32.v v2, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v3, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v4, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v5, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v6, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v7, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v8, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v9, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v10, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v11, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v12, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v13, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v14, (t1)",
+            "addi t1, t1, 16",
+            "vle32.v v15, (t1)",
+            "vaesz.vs v0, v1", // Whiten the counter
+            "vaesem.vs v0, v2", // Apply 14 rounds of AES
+            "vaesem.vs v0, v3",
+            "vaesem.vs v0, v4",
+            "vaesem.vs v0, v5",
+            "vaesem.vs v0, v6",
+            "vaesem.vs v0, v7",
+            "vaesem.vs v0, v8",
+            "vaesem.vs v0, v9",
+            "vaesem.vs v0, v10",
+            "vaesem.vs v0, v11",
+            "vaesem.vs v0, v12",
+            "vaesem.vs v0, v13",
+            "vaesem.vs v0, v14",
+            "vaesef.vs v0, v15",
+            "vse32.v v0, (t0)", // Store the state
+            in("t0") state_ptr,
+            inlateout("t1") round_keys_ptr,
+            out("v0") _,
+            out("v1") _,
+            out("v2") _,
+            out("v3") _,
+            out("v4") _,
+            out("v5") _,
+            out("v6") _,
+            out("v7") _,
+            out("v8") _,
+            out("v9") _,
+            out("v10") _,
+            out("v11") _,
+            out("v12") _,
+            out("v13") _,
+            out("v14") _,
+            out("v15") _,
+            options(nostack),
+        );
 
         // Return the encrypted counter as u128.
-        u128::from(state[0]) | (u128::from(state[1]) << 64)
+        state
     }
 }
 
-#[target_feature(enable = "zkne")]
-unsafe fn aes128_key_expansion(key: [u64; 2]) -> [[u64; 2]; AES128_KEY_COUNT] {
-    unsafe fn generate_round_key<const RNUM: u8>(expanded_keys: &mut [[u64; 2]]) {
-        let prev_key = expanded_keys[RNUM as usize];
-
-        let temp = aes64ks1i::<RNUM>(prev_key[1]);
-        let rk0 = aes64ks2(temp, prev_key[0]);
-        let rk1 = aes64ks2(rk0, prev_key[1]);
-
-        expanded_keys[RNUM as usize + 1] = [rk0, rk1];
-    }
-    let mut expanded_keys = [[0u64; 2]; AES128_KEY_COUNT];
-
-    // Load the initial key.
-    expanded_keys[0] = [key[0], key[1]];
-
-    // The actual key expansion.
-    generate_round_key::<0>(&mut expanded_keys);
-    generate_round_key::<1>(&mut expanded_keys);
-    generate_round_key::<2>(&mut expanded_keys);
-    generate_round_key::<3>(&mut expanded_keys);
-    generate_round_key::<4>(&mut expanded_keys);
-    generate_round_key::<5>(&mut expanded_keys);
-    generate_round_key::<6>(&mut expanded_keys);
-    generate_round_key::<7>(&mut expanded_keys);
-    generate_round_key::<8>(&mut expanded_keys);
-    generate_round_key::<9>(&mut expanded_keys);
+unsafe fn aes128_key_expansion(key: u128) -> [u128; AES128_KEY_COUNT] {
+    let mut expanded_keys = [0u128; AES128_KEY_COUNT];
+    let key_ptr = &key as *const u128;
+    let mut expanded_ptr = (&mut expanded_keys).as_mut_ptr();
+
+    asm!(
+        "vsetivli x0, 4, e32, m4, ta, ma",
+        "vle32.v v0, (t0)", // Load key as state and copy into expanded
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 1", // Round 1
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 2", // Round 2
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 3", // Round 3
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 4", // Round 4
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 5", // Round 5
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 6", // Round 6
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 7", // Round 7
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 8", // Round 8
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 9", // Round 9
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf1.vi v0, v0, 10", // Round 10
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        in("t0") key_ptr,
+        inlateout("t1") expanded_ptr,
+        out("v0") _,
+        options(nostack),
+    );
 
     expanded_keys
 }
 
-#[target_feature(enable = "zkne")]
-unsafe fn aes256_key_expansion(key: [u64; 4]) -> [[u64; 2]; AES256_KEY_COUNT] {
-    unsafe fn generate_round_keys<const RNUM: u8>(
-        expanded_keys: &mut [[u64; 2]; AES256_KEY_COUNT],
-    ) {
-        let prev_key_0 = expanded_keys[RNUM as usize * 2];
-        let prev_key_1 = expanded_keys[(RNUM as usize * 2) + 1];
-
-        let temp = aes64ks1i::<RNUM>(prev_key_1[1]);
-
-        let rk0 = aes64ks2(temp, prev_key_0[0]);
-        let rk1 = aes64ks2(rk0, prev_key_0[1]);
-
-        expanded_keys[(RNUM as usize * 2) + 2] = [rk0, rk1];
-
-        if RNUM < 6 {
-            let temp = aes64ks1i::<0xA>(rk1);
-
-            let rk2 = aes64ks2(temp, prev_key_1[0]);
-            let rk3 = aes64ks2(rk2, prev_key_1[1]);
-
-            expanded_keys[(RNUM as usize * 2) + 3] = [rk2, rk3];
-        }
-    }
-    let mut expanded_keys = [[0u64; 2]; AES256_KEY_COUNT];
-
-    // Load the initial key.
-    expanded_keys[0] = [key[0], key[1]];
-    expanded_keys[1] = [key[2], key[3]];
-
-    // The actual key expansion.
-    generate_round_keys::<0>(&mut expanded_keys);
-    generate_round_keys::<1>(&mut expanded_keys);
-    generate_round_keys::<2>(&mut expanded_keys);
-    generate_round_keys::<3>(&mut expanded_keys);
-    generate_round_keys::<4>(&mut expanded_keys);
-    generate_round_keys::<5>(&mut expanded_keys);
-    generate_round_keys::<6>(&mut expanded_keys);
+unsafe fn aes256_key_expansion(key: [u128; 2]) -> [u128; AES256_KEY_COUNT] {
+    let mut expanded_keys = [0u128; AES256_KEY_COUNT];
+    let mut key_ptr = key.as_ptr();
+    let mut expanded_ptr = (&mut expanded_keys).as_mut_ptr();
+
+    asm!(
+        "vsetivli x0, 4, e32, m4, ta, ma",
+        "vle32.v v0, (t0)",
+        "addi t0, t0, 16",
+        "vle32.v v4, (t0)",
+        "vse32.v v0, (t1)",
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 2", // Round 2
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf2.vi v4, v0, 3", // Round 3
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 4", // Round 4
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf2.vi v4, v0, 5", // Round 5
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 6", // Round 6
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf2.vi v4, v0, 7", // Round 7
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 8", // Round 8
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf2.vi v4, v0, 9", // Round 9
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 10", // Round 10
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf2.vi v4, v0, 11", // Round 11
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 12", // Round 12
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        "vaeskf2.vi v4, v0, 13", // Round 13
+        "addi t1, t1, 16",
+        "vse32.v v4, (t1)",
+        "vaeskf2.vi v0, v4, 14", // Round 14
+        "addi t1, t1, 16",
+        "vse32.v v0, (t1)",
+        inlateout("t0") key_ptr,
+        inlateout("t1") expanded_ptr,
+        out("v0") _,
+        out("v4") _,
+        options(nostack),
+    );
 
     expanded_keys
 }
 
-#[cfg(all(
-    test,
-    not(any(
-        not(all(target_arch = "riscv64", target_feature = "zkne")),
-        feature = "force_fallback"
-    ))
-))]
+#[cfg(all(test, not(feature = "force_fallback")))]
 mod tests {
     use super::*;
     use crate::constants::{AES128_KEY_COUNT, AES128_KEY_SIZE, AES_BLOCK_SIZE};
@@ -684,16 +705,14 @@
     #[test]
     fn test_aes128_key_expansion() {
         aes128_key_expansion_test(|key| {
-            let mut key_lo = [0u8; 8];
-            let mut key_hi = [0u8; 8];
-            key_lo.copy_from_slice(&key[0..8]);
-            key_hi.copy_from_slice(&key[8..16]);
-            let key = [u64::from_le_bytes(key_lo), u64::from_le_bytes(key_hi)];
+            let mut key_0 = [0u8; 16];
+            key_0.copy_from_slice(&key[0..16]);
+            let key = u128::from_le_bytes(key_0);
 
-            let expanded: [[u64; 2]; AES128_KEY_COUNT] = unsafe { aes128_key_expansion(key) };
+            let expanded: [u128; AES128_KEY_COUNT] = unsafe { aes128_key_expansion(key) };
             let expanded: [[u8; AES_BLOCK_SIZE]; AES128_KEY_COUNT] = unsafe {
                 core::mem::transmute::<
-                    [[u64; 2]; AES128_KEY_COUNT],
+                    [u128; AES128_KEY_COUNT],
                     [[u8; AES_BLOCK_SIZE]; AES128_KEY_COUNT],
                 >(expanded)
             };
@@ -704,25 +723,16 @@
     #[test]
     fn test_aes256_key_expansion() {
         aes256_key_expansion_test(|key| {
-            let mut key_0_lo = [0u8; 8];
-            let mut key_0_hi = [0u8; 8];
-            let mut key_1_lo = [0u8; 8];
-            let mut key_1_hi = [0u8; 8];
-            key_0_lo.copy_from_slice(&key[0..8]);
-            key_0_hi.copy_from_slice(&key[8..16]);
-            key_1_lo.copy_from_slice(&key[16..24]);
-            key_1_hi.copy_from_slice(&key[24..32]);
-            let key = [
-                u64::from_le_bytes(key_0_lo),
-                u64::from_le_bytes(key_0_hi),
-                u64::from_le_bytes(key_1_lo),
-                u64::from_le_bytes(key_1_hi),
-            ];
+            let mut key_0 = [0u8; 16];
+            let mut key_1 = [0u8; 16];
+            key_0.copy_from_slice(&key[0..16]);
+            key_1.copy_from_slice(&key[16..32]);
+            let key = [u128::from_le_bytes(key_0), u128::from_le_bytes(key_1)];
 
-            let expanded: [[u64; 2]; AES256_KEY_COUNT] = unsafe { aes256_key_expansion(key) };
+            let expanded: [u128; AES256_KEY_COUNT] = unsafe { aes256_key_expansion(key) };
             let expanded: [[u8; AES_BLOCK_SIZE]; AES256_KEY_COUNT] = unsafe {
                 core::mem::transmute::<
-                    [[u64; 2]; AES256_KEY_COUNT],
+                    [u128; AES256_KEY_COUNT],
                     [[u8; AES_BLOCK_SIZE]; AES256_KEY_COUNT],
                 >(expanded)
             };
diff --git a/src/implementation.rs b/src/implementation.rs
index e47a90e..1e041cb 100644
--- a/src/implementation.rs
+++ b/src/implementation.rs
@@ -12,7 +12,7 @@ macro_rules! safely_call {
                     target_feature = "sse2",
                     target_feature = "aes",
                 ),
-                all(target_arch = "riscv64", target_feature = "zkne"),
+                all(target_arch = "riscv64", feature = "experimental_riscv"),
                 all(
                     target_arch = "aarch64",
                     target_feature = "neon",
@@ -32,7 +32,7 @@ macro_rules! safely_call {
                     target_feature = "sse2",
                     target_feature = "aes",
                 ),
-                all(target_arch = "riscv64", target_feature = "zkne"),
+                all(target_arch = "riscv64", feature = "experimental_riscv"),
                 all(
                     target_arch = "aarch64",
                     target_feature = "neon",
diff --git a/src/lib.rs b/src/lib.rs
index 6d7a108..ca11df8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,16 +23,19 @@
 //! Use the following target features for optimal performance:
 //!
 //! - aarch64: `aes` (using the cryptographic extension)
-//! - riscv64: `zkne` (using the scalar based cryptography extension)
 //! - x86_64: `aes` (using AES-NI)
 //!
+//! There is experimental support for the RISC-V vector crypto extension. Please read the README.md
+//! for more information on how to use it.
+//!
 //! ## Security Note
 //!
-//! While based on well-established cryptographic primitives, this PRNG is not intended for cryptographic key generation
-//! or other sensitive cryptographic operations, simply because safe, automatic re-seeding is not provided. We tested its
-//! statistical qualities by running versions with reduced rounds against `practrand` and `TESTu01`'s Big Crush.
-//! A version with just 3 rounds of AES encryption rounds passes the `practrand` tests with at least 16 TB.
-//! `TESTu01`'s Big Crush requires at least 5 rounds to be successfully cleared. AES-128 uses 10 rounds, whereas
+//! While based on well-established cryptographic primitives, this PRNG is not intended for
+//! cryptographic key generation or other sensitive cryptographic operations, simply because safe,
+//! automatic re-seeding is not provided. We tested its statistical qualities by running versions
+//! with reduced rounds against `practrand` and `TESTu01`'s Big Crush. A version with just 3 AES
+//! encryption rounds passes the `practrand` tests with at least 16 TB. `TESTu01`'s Big Crush
+//! requires at least 5 rounds to be successfully cleared. AES-128 uses 10 rounds, whereas
 //! AES-256 uses 14 rounds.
 //!
 //! 
## Parallel Stream Generation @@ -58,7 +61,6 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(feature = "verification", allow(unused))] #![cfg_attr(not(feature = "std"), no_std)] -#![cfg_attr(target_arch = "riscv64", feature(riscv_ext_intrinsics))] pub mod seeds; @@ -75,9 +77,7 @@ mod traits; target_feature = "sse2", target_feature = "aes", ), - all( - target_arch = "riscv64", target_feature = "zkne" - ), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", @@ -97,9 +97,7 @@ pub(crate) mod fallback; target_feature = "sse2", target_feature = "aes", ), - all( - target_arch = "riscv64", target_feature = "zkne" - ), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", @@ -119,9 +117,7 @@ pub use fallback::{Aes128Ctr128, Aes128Ctr64, Aes256Ctr128, Aes256Ctr64}; target_feature = "sse2", target_feature = "aes", ), - all( - target_arch = "riscv64", target_feature = "zkne" - ), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", @@ -142,9 +138,7 @@ pub use hardware::{Aes128Ctr128, Aes128Ctr64, Aes256Ctr128, Aes256Ctr64}; target_feature = "sse2", target_feature = "aes", ), - all( - target_arch = "riscv64", target_feature = "zkne" - ), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", diff --git a/src/tls.rs b/src/tls.rs index e0ba00f..83f68ad 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -17,7 +17,7 @@ use crate::Random; target_feature = "sse2", target_feature = "aes", ), - all(target_arch = "riscv64", target_feature = "zkne"), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", @@ -37,7 +37,7 @@ thread_local! 
{ target_feature = "sse2", target_feature = "aes", ), - all(target_arch = "riscv64", target_feature = "zkne"), + all(target_arch = "riscv64", feature = "experimental_riscv"), all( target_arch = "aarch64", target_feature = "neon", diff --git a/src/verification.rs b/src/verification.rs index 9c82fd4..5428fe2 100644 --- a/src/verification.rs +++ b/src/verification.rs @@ -60,7 +60,9 @@ fn verify_aes128_ctr64(key: [u8; AES128_KEY_SIZE], iv: [u8; AES_BLOCK_SIZE]) { let hardware = unsafe { Aes128Ctr64Hardware::from_seed_impl(key, nonce, ctr) }; for _ in 0..u8::MAX { - assert_eq!(software.next_impl(), unsafe { hardware.next_impl() }); + assert_eq!(software.next_impl().to_le_bytes(), unsafe { + hardware.next_impl().to_le_bytes() + }); } } @@ -69,7 +71,9 @@ fn verify_aes128_ctr128(key: [u8; AES128_KEY_SIZE], iv: [u8; AES_BLOCK_SIZE]) { let hardware = unsafe { Aes128Ctr128Hardware::from_seed_impl(key, iv) }; for _ in 0..u8::MAX { - assert_eq!(software.next_impl(), unsafe { hardware.next_impl() }); + assert_eq!(software.next_impl().to_le_bytes(), unsafe { + hardware.next_impl().to_le_bytes() + }); } } @@ -83,7 +87,9 @@ fn verify_aes256_ctr64(key: [u8; AES256_KEY_SIZE], iv: [u8; AES_BLOCK_SIZE]) { let hardware = unsafe { Aes256Ctr64Hardware::from_seed_impl(key, nonce, ctr) }; for _ in 0..u8::MAX { - assert_eq!(software.next_impl(), unsafe { hardware.next_impl() }); + assert_eq!(software.next_impl().to_le_bytes(), unsafe { + hardware.next_impl().to_le_bytes() + }); } } @@ -92,6 +98,8 @@ fn verify_aes256_ctr128(key: [u8; AES256_KEY_SIZE], iv: [u8; AES_BLOCK_SIZE]) { let hardware = unsafe { Aes256Ctr128Hardware::from_seed_impl(key, iv) }; for _ in 0..u8::MAX { - assert_eq!(software.next_impl(), unsafe { hardware.next_impl() }); + assert_eq!(software.next_impl().to_le_bytes(), unsafe { + hardware.next_impl().to_le_bytes() + }); } } diff --git a/verification/Cargo.toml b/verification/Cargo.toml index 8dc84e5..b71929f 100644 --- a/verification/Cargo.toml +++ b/verification/Cargo.toml @@ -4,8 +4,13 @@ version = "0.1.0" edition = "2021" publish = false +[features] +experimental_riscv = ["rand_aes/experimental_riscv"] +force_fallback = ["rand_aes/force_fallback"] +force_no_runtime_detection = ["rand_aes/force_no_runtime_detection"] + [dependencies] -rand_aes = { path = "..", features = ["verification"] } +rand_aes = { path = "..", features = ["verification"] } [[bin]] name = "verification" diff --git a/verification/src/bin/verification.rs b/verification/src/bin/verification.rs index ff4458f..69b0d79 100644 --- a/verification/src/bin/verification.rs +++ b/verification/src/bin/verification.rs @@ -1,9 +1,7 @@ fn main() { println!("Starting verification"); - unsafe { - rand_aes::verification::run_verification() - }; + unsafe { rand_aes::verification::run_verification() }; println!("Passed verification!"); }
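
---

To exercise the experimental backend without RVV hardware, the verification binary can be run under user-mode
QEMU. This is a minimal sketch that mirrors the riscv64 branch of scripts/run_verification.sh above; it assumes
a riscv64-linux-gnu cross toolchain, a QEMU build recent enough to accept the `v` and `zvkn` CPU flags, and that
the `+v,+zvkn` rustflags from .cargo/config.toml are in effect:

```sh
# Cross-build the verification binary with the experimental vector crypto backend.
CARGO_TARGET_RISCV64GC_UNKNOWN_LINUX_GNU_LINKER=riscv64-linux-gnu-gcc \
    cargo build --release --target=riscv64gc-unknown-linux-gnu \
    --no-default-features --features=experimental_riscv

# Run it on an emulated RV64 core with the vector and vector crypto extensions enabled.
qemu-riscv64 -cpu rv64,v=true,vlen=128,zvkn=true -L /usr/riscv64-linux-gnu \
    ../target/riscv64gc-unknown-linux-gnu/release/verification
```

Run the commands from the verification/ crate directory (the ../target path assumes it); the binary compares the
hand-written vector asm against the software AES implementation and panics on the first mismatch.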