Skip to content

Commit

Permalink
Re-implement "as_array_of_cells" until it's stable.
Browse files Browse the repository at this point in the history
The compiler can optimize the code the same way even if we don't
use this pattern, but using this pattern we write down out intend
more clearly. This pattern seems safe, since "as_slice_of_cells"
is stable and does essentially the same (there are also no open
questions for #88248).
  • Loading branch information
hasenbanck committed Aug 9, 2024
1 parent f2aad93 commit 70e8282
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 136 deletions.
15 changes: 12 additions & 3 deletions src/fallback/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
//! - Fixed: Always uses the software AES implementation.
#[cfg(all(
any(
not(all(feature = "std", any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86"))),
not(all(
feature = "std",
any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")
)),
feature = "force_no_runtime_detection"
),
not(feature = "verification")
Expand All @@ -14,7 +17,10 @@ mod fixed;

#[cfg(all(
not(any(
not(all(feature = "std", any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86"))),
not(all(
feature = "std",
any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")
)),
feature = "force_no_runtime_detection"
)),
not(feature = "verification")
Expand All @@ -33,7 +39,10 @@ pub use runtime::{Aes128Ctr128, Aes128Ctr64, Aes256Ctr128, Aes256Ctr64};

#[cfg(all(
any(
not(all(feature = "std", any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86"))),
not(all(
feature = "std",
any(target_arch = "aarch64", target_arch = "x86_64", target_arch = "x86")
)),
feature = "force_no_runtime_detection"
),
not(feature = "verification")
Expand Down
127 changes: 70 additions & 57 deletions src/hardware/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,30 @@ impl Aes128Ctr64 {
#[cfg_attr(not(target_feature = "neon"), target_feature(enable = "neon"))]
pub(crate) unsafe fn next_impl(&self) -> u128 {
let counter = self.counter.get();
let round_keys = self.round_keys.get();

// Increment the lower 64 bits using SIMD.
let increment = vsetq_lane_u64::<0>(1, vmovq_n_u64(0));
let new_counter = vaddq_u64(counter, increment);
self.counter.set(new_counter);

// SAFETY: `Cell<T>` has the same memory layout as `T`.
// Use `as_array_of_cells` once stable: https://github.com/rust-lang/rust/issues/88248
let rks = &*((&self.round_keys) as *const Cell<[_; AES128_KEY_COUNT]>
as *const [Cell<_>; AES128_KEY_COUNT]);

// We apply the AES encryption on the counter.
let mut state = vreinterpretq_u8_u64(counter);
state = vaesmcq_u8(vaeseq_u8(state, round_keys[0]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[1]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[2]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[3]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[4]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[5]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[6]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[7]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[8]));
state = vaeseq_u8(state, round_keys[9]);
state = veorq_u8(state, round_keys[10]);
state = vaesmcq_u8(vaeseq_u8(state, rks[0].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[1].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[2].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[3].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[4].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[5].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[6].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[7].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[8].get()));
state = vaeseq_u8(state, rks[9].get());
state = veorq_u8(state, rks[10].get());

// Return the encrypted counter as u128.
*(&state as *const uint8x16_t as *const u128)
Expand Down Expand Up @@ -175,21 +179,24 @@ impl Aes128Ctr128 {
let counter = self.counter.get();
self.counter.set(counter.wrapping_add(1));

let round_keys = self.round_keys.get();
// SAFETY: `Cell<T>` has the same memory layout as `T`.
// Use `as_array_of_cells` once stable: https://github.com/rust-lang/rust/issues/88248
let rks = &*((&self.round_keys) as *const Cell<[_; AES128_KEY_COUNT]>
as *const [Cell<_>; AES128_KEY_COUNT]);

// We apply the AES encryption on the whitened counter.
let mut state = vld1q_u8(counter.to_le_bytes().as_ptr().cast());
state = vaesmcq_u8(vaeseq_u8(state, round_keys[0]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[1]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[2]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[3]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[4]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[5]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[6]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[7]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[8]));
state = vaeseq_u8(state, round_keys[9]);
state = veorq_u8(state, round_keys[10]);
state = vaesmcq_u8(vaeseq_u8(state, rks[0].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[1].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[2].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[3].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[4].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[5].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[6].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[7].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[8].get()));
state = vaeseq_u8(state, rks[9].get());
state = veorq_u8(state, rks[10].get());

// Return the encrypted counter as u128.
*(&state as *const uint8x16_t as *const u128)
Expand Down Expand Up @@ -257,30 +264,33 @@ impl Aes256Ctr64 {
#[cfg_attr(not(target_feature = "neon"), target_feature(enable = "neon"))]
pub(crate) unsafe fn next_impl(&self) -> u128 {
let counter = self.counter.get();
let round_keys = self.round_keys.get();

// Increment the lower 64 bits using SIMD.
let increment = vcombine_u64(vdup_n_u64(1), vdup_n_u64(0));
let new_counter = vaddq_u64(counter, increment);
self.counter.set(new_counter);

// SAFETY: `Cell<T>` has the same memory layout as `T`.
// Use `as_array_of_cells` once stable: https://github.com/rust-lang/rust/issues/88248
let rks = &*((&self.round_keys) as *const Cell<[_; AES256_KEY_COUNT]>
as *const [Cell<_>; AES256_KEY_COUNT]);

// We apply the AES encryption on the counter.
let mut state = vreinterpretq_u8_u64(counter);
state = vaesmcq_u8(vaeseq_u8(state, round_keys[0]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[1]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[2]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[3]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[4]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[5]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[6]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[7]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[8]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[9]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[10]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[11]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[12]));
state = vaeseq_u8(state, round_keys[13]);
state = veorq_u8(state, round_keys[14]);
state = vaesmcq_u8(vaeseq_u8(state, rks[0].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[1].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[2].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[3].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[4].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[5].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[6].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[7].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[8].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[9].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[10].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[11].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[12].get()));
state = vaeseq_u8(state, rks[13].get());
state = veorq_u8(state, rks[14].get());

// Return the encrypted counter as u128.
*(&state as *const uint8x16_t as *const u128)
Expand Down Expand Up @@ -356,25 +366,28 @@ impl Aes256Ctr128 {
let counter = self.counter.get();
self.counter.set(counter.wrapping_add(1));

let round_keys = self.round_keys.get();
// SAFETY: `Cell<T>` has the same memory layout as `T`.
// Use `as_array_of_cells` once stable: https://github.com/rust-lang/rust/issues/88248
let rks = &*((&self.round_keys) as *const Cell<[_; AES256_KEY_COUNT]>
as *const [Cell<_>; AES256_KEY_COUNT]);

// We apply the AES encryption on the counter.
let mut state = vld1q_u8(counter.to_le_bytes().as_ptr().cast());
state = vaesmcq_u8(vaeseq_u8(state, round_keys[0]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[1]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[2]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[3]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[4]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[5]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[6]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[7]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[8]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[9]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[10]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[11]));
state = vaesmcq_u8(vaeseq_u8(state, round_keys[12]));
state = vaeseq_u8(state, round_keys[13]);
state = veorq_u8(state, round_keys[14]);
state = vaesmcq_u8(vaeseq_u8(state, rks[0].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[1].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[2].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[3].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[4].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[5].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[6].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[7].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[8].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[9].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[10].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[11].get()));
state = vaesmcq_u8(vaeseq_u8(state, rks[12].get()));
state = vaeseq_u8(state, rks[13].get());
state = veorq_u8(state, rks[14].get());

// Return the encrypted counter as u128.
*(&state as *const uint8x16_t as *const u128)
Expand Down
23 changes: 5 additions & 18 deletions src/hardware/riscv64.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use core::{arch::asm, cell::{Cell, RefCell}};
use core::{
arch::asm,
cell::{Cell, RefCell},
};

use crate::constants::{AES128_KEY_COUNT, AES128_KEY_SIZE, AES256_KEY_COUNT, AES256_KEY_SIZE};

Expand Down Expand Up @@ -72,11 +75,7 @@ impl Aes128Ctr64 {
let mut new_counter = counter;
new_counter[0] = counter[0].wrapping_add(1);
self.counter.set(new_counter);

// We know that there can't be any other reference to its data, and it will also not
// store a reference to it somewhere. So it's safe for the ASM to read from it directly.
// Once there are intrinsic, we can again use the cell type, since then the compiler is
// able to optimize the access to it.

let mut round_keys_ptr = self.round_keys.as_ptr();

// Initialize the state with the counter.
Expand Down Expand Up @@ -214,10 +213,6 @@ impl Aes128Ctr128 {
let counter = self.counter.get();
self.counter.set(counter.wrapping_add(1));

// We know that there can't be any other reference to its data, and it will also not
// store a reference to it somewhere. So it's safe for the ASM to read from it directly.
// Once there are intrinsic, we can again use the cell type, since then the compiler is
// able to optimize the access to it.
let mut round_keys_ptr = self.round_keys.as_ptr();

// Initialize the state with the counter.
Expand Down Expand Up @@ -351,10 +346,6 @@ impl Aes256Ctr64 {
new_counter[0] = counter[0].wrapping_add(1);
self.counter.set(new_counter);

// We know that there can't be any other reference to its data, and it will also not
// store a reference to it somewhere. So it's safe for the ASM to read from it directly.
// Once there are intrinsic, we can again use the cell type, since then the compiler is
// able to optimize the access to it.
let mut round_keys_ptr = self.round_keys.as_ptr();

// Initialize the state with the counter.
Expand Down Expand Up @@ -514,10 +505,6 @@ impl Aes256Ctr128 {
let counter = self.counter.get();
self.counter.set(counter.wrapping_add(1));

// We know that there can't be any other reference to its data, and it will also not
// store a reference to it somewhere. So it's safe for the ASM to read from it directly.
// Once there are intrinsic, we can again use the cell type, since then the compiler is
// able to optimize the access to it.
let mut round_keys_ptr = self.round_keys.as_ptr();

// Initialize the state with the counter.
Expand Down
Loading

0 comments on commit 70e8282

Please sign in to comment.