diff --git a/Cargo.lock b/Cargo.lock
index 4f0507a7..f9df4233 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -28,7 +28,8 @@ name = "chacha20"
 version = "0.10.0-pre.1"
 dependencies = [
  "cfg-if",
- "cipher",
+ "chacha20 0.7.3",
+ "cipher 0.5.0-pre.4",
  "cpufeatures",
  "hex-literal",
  "rand_chacha",
@@ -70,6 +71,16 @@ dependencies = [
  "rand_core 0.6.4",
 ]
 
+[[package]]
+name = "generic-array"
+version = "0.14.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
+dependencies = [
+ "typenum",
+ "version_check",
+]
+
 [[package]]
 name = "getrandom"
 version = "0.2.13"
@@ -85,7 +96,7 @@ dependencies = [
 name = "hc-256"
 version = "0.6.0-pre"
 dependencies = [
- "cipher",
+ "cipher 0.5.0-pre.4",
  "hex-literal",
 ]
 
@@ -154,7 +165,7 @@ dependencies = [
 name = "rabbit"
 version = "0.5.0-pre"
 dependencies = [
- "cipher",
+ "cipher 0.5.0-pre.4",
  "hex-literal",
 ]
 
@@ -190,7 +201,7 @@ dependencies = [
 name = "rc4"
 version = "0.2.0-pre"
 dependencies = [
- "cipher",
+ "cipher 0.5.0-pre.4",
  "hex-literal",
 ]
 
@@ -205,7 +216,7 @@ name = "salsa20"
 version = "0.11.0-pre.1"
 dependencies = [
  "cfg-if",
- "cipher",
+ "cipher 0.5.0-pre.4",
  "hex-literal",
 ]
 
@@ -263,6 +274,12 @@ version = "1.0.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
 
+[[package]]
+name = "version_check"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+
 [[package]]
 name = "wasi"
 version = "0.11.0+wasi-snapshot-preview1"
diff --git a/chacha20/src/backends/avx2.rs b/chacha20/src/backends/avx2.rs
index d0f05d12..3628d818 100644
--- a/chacha20/src/backends/avx2.rs
+++ b/chacha20/src/backends/avx2.rs
@@ -1,8 +1,9 @@
-use crate::Rounds;
+use crate::{Rounds, Variant};
 use core::marker::PhantomData;
+use core::mem::size_of;
 
 #[cfg(feature = "rng")]
-use crate::{ChaChaCore, Variant};
+use crate::ChaChaCore;
 
 #[cfg(feature = "cipher")]
 use crate::{
@@ -33,10 +34,11 @@ const N: usize = PAR_BLOCKS / 2;
 #[inline]
 #[target_feature(enable = "avx2")]
 #[cfg(feature = "cipher")]
-pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
+pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
 where
     R: Rounds,
     F: StreamClosure<BlockSize = U64>,
+    V: Variant,
 {
     let state_ptr = state.as_ptr() as *const __m128i;
     let v = [
@@ -45,21 +47,33 @@ where
         _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
     ];
     let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
-    c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
+    if size_of::<V::Counter>() == 4 {
+        c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
+    } else {
+        c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
+    }
     let mut ctr = [c; N];
     for i in 0..N {
         ctr[i] = c;
-        c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
+        if size_of::<V::Counter>() == 4 {
+            c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
+        } else {
+            c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
+        }
     }
-    let mut backend = Backend::<R> {
+    let mut backend = Backend::<R, V> {
         v,
         ctr,
         _pd: PhantomData,
+        _variant: PhantomData
     };
 
     f.call(&mut backend);
 
     state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
+    if size_of::<V::Counter>() != 4 {
+        state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
+    }
 }
 
 #[inline]
@@ -83,10 +97,11 @@ where
         ctr[i] = c;
         c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
     }
-    let mut backend = Backend::<R> {
+    let mut backend = Backend::<R, V> {
         v,
         ctr,
         _pd: PhantomData,
+        _variant: PhantomData
     };
 
     backend.rng_gen_par_ks_blocks(buffer);
@@ -94,30 +109,35 @@ where
     core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
 }
 
-struct Backend<R: Rounds> {
+struct Backend<R: Rounds, V: Variant> {
     v: [__m256i; 3],
     ctr: [__m256i; N],
     _pd: PhantomData<R>,
+    _variant: PhantomData<V>
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> BlockSizeUser for Backend<R> {
+impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
     type BlockSize = U64;
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
+impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
     type ParBlocksSize = U4;
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> StreamBackend for Backend<R> {
+impl<R: Rounds, V: Variant> StreamBackend for Backend<R, V> {
     #[inline(always)]
     fn gen_ks_block(&mut self, block: &mut Block) {
         unsafe {
             let res = rounds::<R>(&self.v, &self.ctr);
             for c in self.ctr.iter_mut() {
-                *c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1));
+                if size_of::<V::Counter>() == 4 {
+                    *c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1));
+                } else {
+                    *c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, 1, 0, 1));
+                }
             }
 
             let res0: [__m128i; 8] = core::mem::transmute(res[0]);
@@ -136,7 +156,11 @@ impl<R: Rounds> StreamBackend for Backend<R> {
 
             let pb = PAR_BLOCKS as i32;
             for c in self.ctr.iter_mut() {
-                *c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
+                if size_of::<V::Counter>() == 4 {
+                    *c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
+                } else {
+                    *c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64));
+                }
             }
 
             let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
@@ -153,7 +177,7 @@ impl<R: Rounds> StreamBackend for Backend<R> {
 }
 
 #[cfg(feature = "rng")]
-impl<R: Rounds> Backend<R> {
+impl<R: Rounds, V: Variant> Backend<R, V> {
     #[inline(always)]
     fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
         unsafe {
diff --git a/chacha20/src/backends/neon.rs b/chacha20/src/backends/neon.rs
index a4f0be5c..b661c3d4 100644
--- a/chacha20/src/backends/neon.rs
+++ b/chacha20/src/backends/neon.rs
@@ -3,11 +3,11 @@
 //! Adapted from the Crypto++ `chacha_simd` implementation by Jack Lloyd and
 //! Jeffrey Walton (public domain).
 
-use crate::{Rounds, STATE_WORDS};
-use core::{arch::aarch64::*, marker::PhantomData};
+use crate::{Rounds, Variant, STATE_WORDS};
+use core::{arch::aarch64::*, marker::PhantomData, mem::size_of};
 
 #[cfg(feature = "rand_core")]
-use crate::{ChaChaCore, Variant};
+use crate::ChaChaCore;
 
 #[cfg(feature = "cipher")]
 use crate::chacha::Block;
@@ -18,13 +18,14 @@ use cipher::{
     BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamBackend, StreamClosure,
 };
 
-struct Backend<R: Rounds> {
+struct Backend<R: Rounds, V: Variant> {
     state: [uint32x4_t; 4],
     ctrs: [uint32x4_t; 4],
     _pd: PhantomData<R>,
+    _variant: PhantomData<V>,
 }
 
-impl<R: Rounds> Backend<R> {
+impl<R: Rounds, V: Variant> Backend<R, V> {
     #[inline]
     unsafe fn new(state: &mut [u32; STATE_WORDS]) -> Self {
         let state = [
@@ -39,10 +40,11 @@ impl<R: Rounds> Backend<R> {
             vld1q_u32([3, 0, 0, 0].as_ptr()),
             vld1q_u32([4, 0, 0, 0].as_ptr()),
         ];
-        Backend::<R> {
+        Backend::<R, V> {
            state,
            ctrs,
            _pd: PhantomData,
+           _variant: PhantomData,
        }
     }
 }
@@ -50,16 +52,25 @@ impl<R: Rounds> Backend<R> {
 
 #[inline]
 #[cfg(feature = "cipher")]
 #[target_feature(enable = "neon")]
-pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
+pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
 where
     R: Rounds,
     F: StreamClosure<BlockSize = U64>,
+    V: Variant,
 {
-    let mut backend = Backend::<R>::new(state);
+    let mut backend = Backend::<R, V>::new(state);
 
     f.call(&mut backend);
 
-    vst1q_u32(state.as_mut_ptr().offset(12), backend.state[3]);
+    if core::mem::size_of::<V::Counter>() == 4 {
+        // handle 32-bit counter
+        vst1q_u32(state.as_mut_ptr().offset(12), backend.state[3]);
+    } else {
+        vst1q_u64(
+            state.as_mut_ptr().offset(12) as *mut u64,
+            vreinterpretq_u64_u32(backend.state[3]),
+        );
+    }
 }
 
@@ -72,7 +83,7 @@ where
     R: Rounds,
     V: Variant,
 {
-    let mut backend = Backend::<R>::new(&mut core.state);
+    let mut backend = Backend::<R, V>::new(&mut core.state);
 
     backend.write_par_ks_blocks(buffer);
 
@@ -80,20 +91,25 @@ where
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> BlockSizeUser for Backend<R> {
+impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
     type BlockSize = U64;
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
+impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
     type ParBlocksSize = U4;
 }
 
-macro_rules! add64 {
-    ($a:expr, $b:expr) => {
-        vreinterpretq_u32_u64(vaddq_u64(
-            vreinterpretq_u64_u32($a),
-            vreinterpretq_u64_u32($b),
-        ))
+/// Adds a counter row with either 32-bit or 64-bit addition
+macro_rules! add_counter {
+    ($a:expr, $b:expr, $variant:ty) => {
+        if size_of::<<$variant>::Counter>() == 4 {
+            vaddq_u32($a, $b)
+        } else {
+            vreinterpretq_u32_u64(vaddq_u64(
+                vreinterpretq_u64_u32($a),
+                vreinterpretq_u64_u32($b),
+            ))
+        }
     };
 }
@@ -105,7 +121,7 @@ macro_rules! add_assign_vec {
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> StreamBackend for Backend<R> {
+impl<R: Rounds, V: Variant> StreamBackend for Backend<R, V> {
     #[inline(always)]
     fn gen_ks_block(&mut self, block: &mut Block) {
         let state3 = self.state[3];
@@ -113,7 +129,7 @@ impl<R: Rounds> StreamBackend for Backend<R> {
         self.gen_par_ks_blocks(&mut par);
         *block = par[0];
         unsafe {
-            self.state[3] = add64!(state3, vld1q_u32([1, 0, 0, 0].as_ptr()));
+            self.state[3] = add_counter!(state3, vld1q_u32([1, 0, 0, 0].as_ptr()), V);
         }
     }
 
@@ -126,19 +142,19 @@ impl<R: Rounds> StreamBackend for Backend<R> {
                     self.state[0],
                     self.state[1],
                     self.state[2],
-                    add64!(self.state[3], self.ctrs[0]),
+                    add_counter!(self.state[3], self.ctrs[0], V),
                 ],
                 [
                     self.state[0],
                     self.state[1],
                     self.state[2],
-                    add64!(self.state[3], self.ctrs[1]),
+                    add_counter!(self.state[3], self.ctrs[1], V),
                 ],
                 [
                     self.state[0],
                     self.state[1],
                     self.state[2],
-                    add64!(self.state[3], self.ctrs[2]),
+                    add_counter!(self.state[3], self.ctrs[2], V),
                 ],
             ];
 
@@ -146,23 +162,39 @@ impl<R: Rounds> StreamBackend for Backend<R> {
                 double_quarter_round(&mut blocks);
             }
 
-            for block in 0..4 {
-                // add state to block
-                for state_row in 0..4 {
+            // write first block, with no special counter requirements
+            for state_row in 0..4 {
+                // add state
+                add_assign_vec!(blocks[0][state_row], self.state[state_row]);
+                // write
+                vst1q_u8(
+                    dest[0]
+                        .as_mut_ptr()
+                        .offset((state_row as isize) << 4 as isize),
+                    vreinterpretq_u8_u32(blocks[0][state_row as usize]),
+                );
+            }
+
+            // write blocks with adjusted counters
+            for block in 1..4 {
+                // add state with adjusted counter
+                for state_row in 0..3 {
                     add_assign_vec!(blocks[block][state_row], self.state[state_row]);
                 }
-                if block > 0 {
-                    blocks[block][3] = add64!(blocks[block][3], self.ctrs[block - 1]);
-                }
-                // write blocks to dest
+                add_assign_vec!(
+                    blocks[block][3],
+                    add_counter!(self.state[3], self.ctrs[block - 1], V)
+                );
+
+                // write
                 for state_row in 0..4 {
                     vst1q_u8(
-                        dest[block].as_mut_ptr().offset(state_row << 4),
+                        dest[block].as_mut_ptr().offset(state_row << 4 as usize),
                         vreinterpretq_u8_u32(blocks[block][state_row as usize]),
                     );
                 }
             }
-            self.state[3] = add64!(self.state[3], self.ctrs[3]);
+            self.state[3] = add_counter!(self.state[3], self.ctrs[3], V);
         }
     }
 }
@@ -188,7 +220,7 @@ macro_rules! extract {
     };
 }
 
-impl<R: Rounds> Backend<R> {
+impl<R: Rounds, V: Variant> Backend<R, V> {
     #[inline(always)]
     /// Generates `num_blocks` blocks and blindly writes them to `dest_ptr`
     ///
@@ -205,19 +237,19 @@ impl<R: Rounds> Backend<R> {
                     self.state[0],
                     self.state[1],
                     self.state[2],
-                    add64!(self.state[3], self.ctrs[0]),
+                    add_counter!(self.state[3], self.ctrs[0], V),
                 ],
                 [
                     self.state[0],
                     self.state[1],
                     self.state[2],
-                    add64!(self.state[3], self.ctrs[1]),
+                    add_counter!(self.state[3], self.ctrs[1], V),
                 ],
                 [
                     self.state[0],
                     self.state[1],
                     self.state[2],
-                    add64!(self.state[3], self.ctrs[2]),
+                    add_counter!(self.state[3], self.ctrs[2], V),
                 ],
             ];
 
@@ -232,7 +264,7 @@ impl<R: Rounds> Backend<R> {
                     add_assign_vec!(blocks[block][state_row], self.state[state_row]);
                 }
                 if block > 0 {
-                    blocks[block][3] = add64!(blocks[block][3], self.ctrs[block - 1]);
+                    blocks[block][3] = add_counter!(blocks[block][3], self.ctrs[block - 1], V);
                 }
                 // write blocks to buffer
                 for state_row in 0..4 {
@@ -243,7 +275,7 @@ impl<R: Rounds> Backend<R> {
                 }
                 dest_ptr = dest_ptr.add(64);
             }
-            self.state[3] = add64!(self.state[3], self.ctrs[3]);
+            self.state[3] = add_counter!(self.state[3], self.ctrs[3], V);
         }
     }
 
diff --git a/chacha20/src/backends/soft.rs b/chacha20/src/backends/soft.rs
index 9cd4234f..dab5df0c 100644
--- a/chacha20/src/backends/soft.rs
+++ b/chacha20/src/backends/soft.rs
@@ -2,6 +2,7 @@
 //! intrinsics.
 
 use crate::{quarter_round, ChaChaCore, Rounds, Variant, STATE_WORDS};
+use core::mem::size_of;
 
 #[cfg(feature = "cipher")]
 use crate::chacha::Block;
@@ -28,7 +29,18 @@ impl<'a, R: Rounds, V: Variant> StreamBackend for Backend<'a, R, V> {
     #[inline(always)]
     fn gen_ks_block(&mut self, block: &mut Block) {
         let res = run_rounds::<R>(&self.0.state);
-        self.0.state[12] = self.0.state[12].wrapping_add(1);
+
+        if size_of::<V::Counter>() == 4 {
+            self.0.state[12] = self.0.state[12].wrapping_add(1);
+        } else {
+            let no_carry = self.0.state[12].checked_add(1);
+            if let Some(v) = no_carry {
+                self.0.state[12] = v;
+            } else {
+                self.0.state[12] = 0;
+                self.0.state[13] = self.0.state[13].wrapping_add(1);
+            }
+        }
 
         for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) {
             chunk.copy_from_slice(&val.to_le_bytes());
diff --git a/chacha20/src/backends/sse2.rs b/chacha20/src/backends/sse2.rs
index 748c59c0..48d18162 100644
--- a/chacha20/src/backends/sse2.rs
+++ b/chacha20/src/backends/sse2.rs
@@ -1,7 +1,7 @@
-use crate::Rounds;
+use crate::{Rounds, Variant};
 
 #[cfg(feature = "rng")]
-use crate::{ChaChaCore, Variant};
+use crate::ChaChaCore;
 
 #[cfg(feature = "cipher")]
 use crate::{STATE_WORDS, chacha::Block};
@@ -14,6 +14,7 @@ use cipher::{
     ParBlocksSizeUser
 };
 use core::marker::PhantomData;
+use core::mem::size_of;
 
 #[cfg(target_arch = "x86")]
 use core::arch::x86::*;
@@ -23,13 +24,14 @@ use core::arch::x86_64::*;
 #[inline]
 #[target_feature(enable = "sse2")]
 #[cfg(feature = "cipher")]
-pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
+pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
 where
     R: Rounds,
     F: StreamClosure<BlockSize = U64>,
+    V: Variant,
 {
     let state_ptr = state.as_ptr() as *const __m128i;
-    let mut backend = Backend::<R> {
+    let mut backend = Backend::<R, V> {
         v: [
             _mm_loadu_si128(state_ptr.add(0)),
             _mm_loadu_si128(state_ptr.add(1)),
@@ -37,35 +39,44 @@ where
             _mm_loadu_si128(state_ptr.add(3)),
         ],
         _pd: PhantomData,
+        _variant: PhantomData
     };
 
     f.call(&mut backend);
 
     state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
+    if size_of::<V::Counter>() != 4 {
+        state[13] = _mm_extract_epi32(backend.v[3], 1) as u32;
+    }
 }
 
-struct Backend<R: Rounds> {
+struct Backend<R: Rounds, V: Variant> {
     v: [__m128i; 4],
     _pd: PhantomData<R>,
+    _variant: PhantomData<V>
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> BlockSizeUser for Backend<R> {
+impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
     type BlockSize = U64;
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
+impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
     type ParBlocksSize = U1;
 }
 
 #[cfg(feature = "cipher")]
-impl<R: Rounds> StreamBackend for Backend<R> {
+impl<R: Rounds, V: Variant> StreamBackend for Backend<R, V> {
     #[inline(always)]
     fn gen_ks_block(&mut self, block: &mut Block) {
         unsafe {
             let res = rounds::<R>(&self.v);
-            self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1));
+            if size_of::<V::Counter>() == 4 {
+                self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1));
+            } else {
+                self.v[3] = _mm_add_epi64(self.v[3], _mm_set_epi64x(0, 1));
+            }
 
             let block_ptr = block.as_mut_ptr() as *mut __m128i;
             for i in 0..4 {
@@ -84,7 +95,7 @@ where
     V: Variant
 {
     let state_ptr = core.state.as_ptr() as *const __m128i;
-    let mut backend = Backend::<R> {
+    let mut backend = Backend::<R, V> {
         v: [
             _mm_loadu_si128(state_ptr.add(0)),
             _mm_loadu_si128(state_ptr.add(1)),
@@ -92,6 +103,7 @@ where
             _mm_loadu_si128(state_ptr.add(3)),
         ],
         _pd: PhantomData,
+        _variant: PhantomData
     };
 
     for i in 0..4 {
@@ -102,7 +114,7 @@ where
 }
 
 #[cfg(feature = "rng")]
-impl<R: Rounds> Backend<R> {
+impl<R: Rounds, V: Variant> Backend<R, V> {
     #[inline(always)]
     fn gen_ks_block(&mut self, block: &mut [u32]) {
         unsafe {
diff --git a/chacha20/src/lib.rs b/chacha20/src/lib.rs
index 60ec60a0..e67496bc 100644
--- a/chacha20/src/lib.rs
+++ b/chacha20/src/lib.rs
@@ -265,16 +265,17 @@ impl<R: Rounds, V: Variant> ChaChaCore<R, V> {
 
 #[cfg(feature = "cipher")]
 impl<R: Rounds, V: Variant> StreamCipherSeekCore for ChaChaCore<R, V> {
-    type Counter = u32;
+    type Counter = V::Counter;
 
     #[inline(always)]
     fn get_block_pos(&self) -> Self::Counter {
-        self.state[12]
+        V::get_block_pos(&self.state[12..V::NONCE_INDEX])
     }
 
     #[inline(always)]
     fn set_block_pos(&mut self, pos: Self::Counter) {
-        self.state[12] = pos
+        let block_pos_words = V::set_block_pos_helper(pos);
+        self.state[12..V::NONCE_INDEX].copy_from_slice(block_pos_words.as_ref())
     }
 }
 
@@ -282,8 +283,7 @@ impl<R: Rounds, V: Variant> StreamCipherSeekCore for ChaChaCore<R, V> {
 impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> {
     #[inline(always)]
     fn remaining_blocks(&self) -> Option<usize> {
-        let rem = u32::MAX - self.get_block_pos();
-        rem.try_into().ok()
+        V::remaining_blocks(self.get_block_pos())
     }
 
     fn process_with_backend(&mut self, f: impl cipher::StreamClosure<BlockSize = Self::BlockSize>) {
@@ -294,21 +294,21 @@ impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> {
                 cfg_if! {
                     if #[cfg(chacha20_force_avx2)] {
                         unsafe {
-                            backends::avx2::inner::<R, _>(&mut self.state, f);
+                            backends::avx2::inner::<R, _, V>(&mut self.state, f);
                         }
                     } else if #[cfg(chacha20_force_sse2)] {
                         unsafe {
-                            backends::sse2::inner::<R, _>(&mut self.state, f);
+                            backends::sse2::inner::<R, _, V>(&mut self.state, f);
                         }
                     } else {
                         let (avx2_token, sse2_token) = self.tokens;
                         if avx2_token.get() {
                             unsafe {
-                                backends::avx2::inner::<R, _>(&mut self.state, f);
+                                backends::avx2::inner::<R, _, V>(&mut self.state, f);
                             }
                         } else if sse2_token.get() {
                             unsafe {
-                                backends::sse2::inner::<R, _>(&mut self.state, f);
+                                backends::sse2::inner::<R, _, V>(&mut self.state, f);
                             }
                         } else {
                             f.call(&mut backends::soft::Backend(self));
@@ -317,7 +317,7 @@ impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> {
                 }
             } else if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
                 unsafe {
-                    backends::neon::inner::<R, _>(&mut self.state, f);
+                    backends::neon::inner::<R, _, V>(&mut self.state, f);
                 }
             } else {
                 f.call(&mut backends::soft::Backend(self));
diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs
index 05dd2e06..d4923dce 100644
--- a/chacha20/src/rng.rs
+++ b/chacha20/src/rng.rs
@@ -1101,4 +1101,21 @@ pub(crate) mod tests {
             assert_eq!(rng1.next_u64(), rng2.next_u64());
         }
     }
+
+    #[test]
+    fn counter_wrapping() {
+        let mut rng = ChaChaRng::from_seed([0u8; 32]);
+
+        // get first four blocks and word pos
+        let mut first_blocks = [0u8; 64 * 4];
+        rng.fill_bytes(&mut first_blocks);
+        let word_pos = rng.get_word_pos();
+
+        // get first four blocks after wrapping
+        rng.set_block_pos(u32::MAX);
+        let mut result = [0u8; 64 * 5];
+        rng.fill_bytes(&mut result);
+        assert_eq!(word_pos, rng.get_word_pos());
+        assert_eq!(&first_blocks[0..64 * 4], &result[64..]);
+    }
 }
diff --git a/chacha20/src/variants.rs b/chacha20/src/variants.rs
index 58043a75..d2f50c8a 100644
--- a/chacha20/src/variants.rs
+++ b/chacha20/src/variants.rs
@@ -2,10 +2,34 @@
 //!
 //! To be revisited for the 64-bit counter.
 
+#[cfg(feature = "cipher")]
+/// A trait to restrict the counter for the cipher crate
+pub trait VariantCounter: cipher::Counter {}
+#[cfg(not(feature = "cipher"))]
+pub trait VariantCounter {}
+
+impl VariantCounter for u32 {}
+
+#[cfg(feature = "legacy")]
+impl VariantCounter for u64 {}
+
 /// A trait that distinguishes some ChaCha variants
 pub trait Variant: Clone {
     /// the size of the Nonce in u32s
     const NONCE_INDEX: usize;
+    type Counter: VariantCounter;
+    type CounterWords: AsRef<[u32]>;
+
+    /// Takes a slice of state[12..NONCE_INDEX] to convert it into
+    /// Self::Counter.
+    fn get_block_pos(counter_row: &[u32]) -> Self::Counter;
+
+    /// Breaks down the Self::Counter type into a u32 array for setting the
+    /// block pos.
+    fn set_block_pos_helper(value: Self::Counter) -> Self::CounterWords;
+
+    /// A helper method for calculating the remaining blocks using these types
+    fn remaining_blocks(block_pos: Self::Counter) -> Option<usize>;
 }
 
 #[derive(Clone)]
@@ -13,6 +37,20 @@ pub trait Variant: Clone {
 pub struct Ietf();
 impl Variant for Ietf {
     const NONCE_INDEX: usize = 13;
+    type Counter = u32;
+    type CounterWords = [u32; 1];
+    #[inline(always)]
+    fn get_block_pos(counter_row: &[u32]) -> Self::Counter {
+        counter_row[0]
+    }
+    #[inline(always)]
+    fn set_block_pos_helper(value: Self::Counter) -> Self::CounterWords {
+        [value]
+    }
+    #[inline(always)]
+    fn remaining_blocks(block_pos: Self::Counter) -> Option<usize> {
+        (u32::MAX - block_pos).try_into().ok()
+    }
 }
 
 #[derive(Clone)]
@@ -22,4 +60,23 @@ pub struct Legacy();
 
 #[cfg(feature = "legacy")]
 impl Variant for Legacy {
     const NONCE_INDEX: usize = 14;
+    type Counter = u64;
+    type CounterWords = [u32; 2];
+    #[inline(always)]
+    fn get_block_pos(counter_row: &[u32]) -> Self::Counter {
+        counter_row[0] as u64 | (u64::from(counter_row[1]) << 32)
+    }
+    #[inline(always)]
+    fn set_block_pos_helper(value: Self::Counter) -> Self::CounterWords {
+        [value as u32, (value >> 32) as u32]
+    }
+    #[inline(always)]
+    fn remaining_blocks(block_pos: Self::Counter) -> Option<usize> {
+        let remaining = u64::MAX - block_pos;
+        #[cfg(target_pointer_width = "32")]
+        if remaining > usize::MAX as u64 {
+            return None;
+        }
+        remaining.try_into().ok()
+    }
 }
diff --git a/chacha20/tests/mod.rs b/chacha20/tests/mod.rs
index 4e4aa33c..21bf8e6f 100644
--- a/chacha20/tests/mod.rs
+++ b/chacha20/tests/mod.rs
@@ -233,4 +233,77 @@ mod legacy {
             }
         }
     }
+
+    /// Tests the 64-bit counter with a given amount of test blocks
+    fn legacy_counter_over_u32_max<const N: usize>(test: &[u8; N]) {
+        assert!(N % 64 == 0, "N should be a multiple of 64");
+        use cipher::StreamCipherSeekCore;
+        // using rand_chacha v0.3 because it is already a dev-dependency, and
+        // it uses a 64-bit counter
+        use rand_chacha::{ChaCha20Rng as OgRng, rand_core::{RngCore, SeedableRng}};
+        let mut cipher = ChaCha20Legacy::new(&[0u8; 32].into(), &LegacyNonce::from([0u8; 8]));
+        let mut rng = OgRng::from_seed([0u8; 32]);
+
+        let mut expected = test.clone();
+        rng.fill_bytes(&mut expected);
+        let mut result = test.clone();
+        cipher.apply_keystream(&mut result);
+        assert_eq!(expected, result);
+
+        const SEEK_POS: u64 = (u32::MAX - 10) as u64 * 64;
+        cipher.seek(SEEK_POS);
+        rng.set_word_pos(SEEK_POS as u128 / 4);
+
+        let pos: u64 = cipher.current_pos();
+        assert_eq!(pos, rng.get_word_pos() as u64 * 4);
+        let block_pos = cipher.get_core().get_block_pos();
+        assert!(block_pos < u32::MAX as u64);
+        // Apply keystream blocks until some point after the u32 boundary
+        for i in 1..16 {
+            let starting_block_pos = cipher.get_core().get_block_pos() as i64 - u32::MAX as i64;
+            let mut expected = test.clone();
+            rng.fill_bytes(&mut expected);
+            let mut result = test.clone();
+            cipher.apply_keystream(&mut result);
+            if expected != result {
+                let mut index: usize = 0;
+                let mut expected_u8: u8 = 0;
+                let mut found_u8: u8 = 0;
+                for (i, (e, r)) in expected.iter().zip(result.iter()).enumerate() {
+                    if e != r {
+                        index = i;
+                        expected_u8 = *e;
+                        found_u8 = *r;
+                        break;
+                    }
+                };
+                panic!("Index {} did not match;\n iteration: {}\n expected: {} != {}\nstart block pos - u32::MAX: {}", index, i, expected_u8, found_u8, starting_block_pos);
+            }
+            let expected_block_pos = block_pos + i * (test.len() / 64) as u64;
+            assert!(expected_block_pos == cipher.get_core().get_block_pos(),
+                "Block pos did not increment as expected; Expected block pos: {}\n actual block_pos: {}\n iteration: {}",
+                expected_block_pos,
+                cipher.get_core().get_block_pos(),
+                i
+            );
+        }
+        // this test assures us that the counter is in fact over u32::MAX, in
+        // case we change some of the parameters
+        assert!(cipher.get_core().get_block_pos() > u32::MAX as u64, "The 64-bit counter test did not surpass u32::MAX");
+    }
+
+    /// Runs the legacy_64_bit_counter test with different-sized arrays so that
+    /// both `gen_ks_block` and `gen_par_ks_blocks` are called with varying
+    /// starting positions.
+    #[test]
+    fn legacy_64_bit_counter() {
+        legacy_counter_over_u32_max(&[0u8; 64 * 1]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 2]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 3]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 4]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 5]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 6]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 7]);
+        legacy_counter_over_u32_max(&[0u8; 64 * 8]);
+    }
 }
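A minimal usage sketch of the widened Legacy counter follows (not part of the diff above). It assumes the crate's `legacy` feature and the `cipher` traits re-exported as `chacha20::cipher`; the seek position and variable names are illustrative only. With `Variant::Counter = u64` for `Legacy`, seeking past the old 32-bit block-counter boundary should no longer wrap.

    // Hedged sketch, assuming the `legacy` feature and the re-exported cipher traits.
    use chacha20::{ChaCha20Legacy, LegacyNonce};
    use chacha20::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};

    fn main() {
        let mut cipher = ChaCha20Legacy::new(&[0u8; 32].into(), &LegacyNonce::from([0u8; 8]));
        // One block past the former u32 block-counter limit (u32::MAX + 1 blocks of 64 bytes).
        let seek_pos = (u32::MAX as u64 + 1) * 64;
        cipher.seek(seek_pos);
        let mut buf = [0u8; 128];
        cipher.apply_keystream(&mut buf);
        // current_pos() is generic over SeekNum; a u64 position now fits without overflow.
        let pos: u64 = cipher.current_pos();
        assert_eq!(pos, seek_pos + 128);
    }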