From d44e4b63a296b2a2bf3ad871b80017768bbb69c6 Mon Sep 17 00:00:00 2001
From: Fionn Malone
Date: Wed, 16 Aug 2023 16:22:13 -0700
Subject: [PATCH] Add left/right shift and addition to simd_bits. (#603)

Following #598, this adds +=, >>= and <<= to simd_bits. It wasn't obvious
to me that these could use word-level parallelism without using more
memory. For example, the shifts could compute the relevant carry masks and
OR them in at the end, but that would require a temporary the same size as
the simd_bits instance.
---
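Note on the approach: operator+= below works one 64-bit word at a time and
ripples the carry into the next word. A minimal standalone sketch of that
idea (illustrative only; add_words is a made-up helper, not part of this
patch):

    #include <cstddef>
    #include <cstdint>

    // Adds rhs into lhs, both little-endian arrays of n 64-bit words.
    // The carry out of word w is detected by the unsigned wrap-around
    // check (sum < addend) and folded into word w + 1.
    void add_words(uint64_t *lhs, const uint64_t *rhs, size_t n) {
        for (size_t w = 0; w + 1 < n; w++) {
            lhs[w] += rhs[w];
            lhs[w + 1] += (lhs[w] < rhs[w]);  // 1 if the addition wrapped.
        }
        lhs[n - 1] += rhs[n - 1];  // Carry out of the top word is dropped.
    }

Because each word's carry depends on the word below it, a fully word-parallel
version would indeed need somewhere to stash the per-word carries, which is
the extra temporary mentioned above.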
 src/stim/mem/simd_bits.h                 |   6 +
 src/stim/mem/simd_bits.inl               |  20 ++-
 src/stim/mem/simd_bits.test.cc           | 193 +++++++++++++++++++++++
 src/stim/mem/simd_bits_range_ref.h       |   5 +
 src/stim/mem/simd_bits_range_ref.inl     |  71 ++++++++-
 src/stim/mem/simd_bits_range_ref.test.cc |  97 ++++++++++++
 6 files changed, 390 insertions(+), 2 deletions(-)

diff --git a/src/stim/mem/simd_bits.h b/src/stim/mem/simd_bits.h
index f645bae20..cdd1e5f28 100644
--- a/src/stim/mem/simd_bits.h
+++ b/src/stim/mem/simd_bits.h
@@ -66,6 +66,12 @@ struct simd_bits {
     // Mask assignment.
     simd_bits &operator&=(const simd_bits_range_ref<W> other);
     simd_bits &operator|=(const simd_bits_range_ref<W> other);
+    // Addition assignment.
+    simd_bits &operator+=(const simd_bits_range_ref<W> other);
+    // Right shift assignment.
+    simd_bits &operator>>=(int offset);
+    // Left shift assignment.
+    simd_bits &operator<<=(int offset);
 
     // Swap assignment.
     simd_bits &swap_with(simd_bits_range_ref<W> other);
diff --git a/src/stim/mem/simd_bits.inl b/src/stim/mem/simd_bits.inl
index 2cabf31a3..72f418f70 100644
--- a/src/stim/mem/simd_bits.inl
+++ b/src/stim/mem/simd_bits.inl
@@ -250,6 +250,24 @@ simd_bits<W> &simd_bits<W>::operator|=(const simd_bits_range_ref<W> other) {
     return *this;
 }
 
+template <size_t W>
+simd_bits<W> &simd_bits<W>::operator+=(const simd_bits_range_ref<W> other) {
+    simd_bits_range_ref<W>(*this) += other;
+    return *this;
+}
+
+template <size_t W>
+simd_bits<W> &simd_bits<W>::operator>>=(int offset) {
+    simd_bits_range_ref<W>(*this) >>= offset;
+    return *this;
+}
+
+template <size_t W>
+simd_bits<W> &simd_bits<W>::operator<<=(int offset) {
+    simd_bits_range_ref<W>(*this) <<= offset;
+    return *this;
+}
+
 template <size_t W>
 bool simd_bits<W>::not_zero() const {
     return simd_bits_range_ref<W>(*this).not_zero();
@@ -289,4 +307,4 @@ std::ostream &operator<<(std::ostream &out, const simd_bits<W> m) {
     return out << simd_bits_range_ref<W>(m);
 }
 
-}
+}  // namespace stim
diff --git a/src/stim/mem/simd_bits.test.cc b/src/stim/mem/simd_bits.test.cc
index 343292fda..6cdfcefed 100644
--- a/src/stim/mem/simd_bits.test.cc
+++ b/src/stim/mem/simd_bits.test.cc
@@ -14,6 +14,8 @@
 
 #include "stim/mem/simd_bits.h"
 
+#include <random>
+
 #include "gtest/gtest.h"
 
 #include "stim/mem/simd_util.h"
@@ -160,6 +162,197 @@ TEST_EACH_WORD_SIZE_W(simd_bits, xor_assignment, {
     }
 })
 
+TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, {
+    simd_bits<W> m0(512);
+    simd_bits<W> m1(512);
+    uint64_t all_set = 0xFFFFFFFFFFFFFFFFULL;
+    uint64_t on_off = 0x0F0F0F0F0F0F0F0FULL;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        for (size_t k = 0; k < 64; k++) {
+            if (word % 2 == 0) {
+                m0[word * 64 + k] = all_set & (1ULL << k);
+                m1[word * 64 + k] = all_set & (1ULL << k);
+            } else {
+                m0[word * 64 + k] = (bool)(on_off & (1ULL << k));
+                m1[word * 64 + k] = (bool)(on_off & (1ULL << k));
+            }
+        }
+    }
+    m0 += m1;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word % 2 == 0) {
+            ASSERT_EQ(pattern, 0xFFFFFFFFFFFFFFFEULL);
+        } else {
+            ASSERT_EQ(pattern, 0x1E1E1E1E1E1E1E1FULL);
+        }
+    }
+    for (size_t k = 0; k < m0.num_u64_padded() / 2; k++) {
+        m1.u64[2 * k] = 0ULL;
+        m1.u64[2 * k + 1] = 0ULL;
+    }
+    m0 += m1;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word % 2 == 0) {
+            ASSERT_EQ(pattern, 0xFFFFFFFFFFFFFFFEULL);
+        } else {
+            ASSERT_EQ(pattern, 0x1E1E1E1E1E1E1E1FULL);
+        }
+    }
+    m0.clear();
+    m1.clear();
+    m1[0] = 1;
+    for (int i = 0; i < 512; i++) {
+        m0 += m1;
+    }
+    for (size_t k = 0; k < 64; k++) {
+        if (k == 9) {
+            ASSERT_EQ(m0[k], 1);
+        } else {
+            ASSERT_EQ(m0[k], 0);
+        }
+    }
+    m0.clear();
+    for (size_t k = 0; k < 64; k++) {
+        m0[k] = all_set & (1ULL << k);
+    }
+    m0 += m1;
+    ASSERT_EQ(m0[0], 0);
+    ASSERT_EQ(m0[64], 1);
+})
+
+TEST_EACH_WORD_SIZE_W(simd_bits, right_shift_assignment, {
+    simd_bits<W> m0(512), m1(512);
+    m0[511] = 1;
+    m0 >>= 64;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word != m0.num_u64_padded() - 2) {
+            ASSERT_EQ(pattern, 0ULL);
+        } else {
+            ASSERT_EQ(pattern, uint64_t{1} << 63);
+        }
+    }
+    m1 = m0;
+    m1 >>= 0;
+    for (size_t k = 0; k < 512; k++) {
+        ASSERT_EQ(m0[k], m1[k]);
+    }
+    m0.clear();
+    uint64_t on_off = 0xAAAAAAAAAAAAAAAAULL;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        for (size_t k = 0; k < 64; k++) {
+            m0[word * 64 + k] = (bool)(on_off & (1ULL << k));
+        }
+    }
+    m0 >>= 1;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        ASSERT_EQ(pattern, 0x5555555555555555ULL);
+    }
+    m0.clear();
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        for (size_t k = 0; k < 64; k++) {
+            m0[word * 64 + k] = (bool)(on_off & (1ULL << k));
+        }
+    }
+    m0 >>= 128;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word < 6) {
+            ASSERT_EQ(pattern, 0xAAAAAAAAAAAAAAAAULL);
+        } else {
+            ASSERT_EQ(pattern, 0ULL);
+        }
+    }
+})
+
+TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_right_shift_assignment, {
+    auto rng = SHARED_TEST_RNG();
+    for (int i = 0; i < 5; i++) {
+        std::uniform_int_distribution dist_bits(1, 1200);
+        int num_bits = dist_bits(rng);
+        simd_bits<W> m1(num_bits), m2(num_bits);
+        m1.randomize(num_bits, rng);
+        m2 = m1;
+        std::uniform_int_distribution dist_shift(0, (int)m1.num_bits_padded());
+        size_t shift = dist_shift(rng);
+        m1 >>= shift;
+        for (size_t k = 0; k < m1.num_bits_padded() - shift; k++) {
+            ASSERT_EQ(m1[k], m2[k + shift]);
+        }
+        for (size_t k = m1.num_bits_padded() - shift; k < m1.num_bits_padded(); k++) {
+            ASSERT_EQ(m1[k], 0);
+        }
+    }
+})
+
+TEST_EACH_WORD_SIZE_W(simd_bits, left_shift_assignment, {
+    simd_bits<W> m0(512), m1(512);
+    for (size_t w = 0; w < m0.num_u64_padded(); w++) {
+        m0.u64[w] = 0xAAAAAAAAAAAAAAAAULL;
+    }
+    m0 <<= 1;
+    m1 = m0;
+    m1 <<= 0;
+    for (size_t k = 0; k < 512; k++) {
+        ASSERT_EQ(m0[k], m1[k]);
+    }
+    for (size_t w = 0; w < m0.num_u64_padded(); w++) {
+        if (w == 0) {
+            ASSERT_EQ(m0.u64[w], 0x5555555555555554ULL);
+        } else {
+            ASSERT_EQ(m0.u64[w], 0x5555555555555555ULL);
+        }
+    }
+    m0 <<= 63;
+    for (size_t w = 0; w < m0.num_u64_padded(); w++) {
+        if (w == 0) {
+            ASSERT_EQ(m0.u64[w], 0ULL);
+        } else {
+            ASSERT_EQ(m0.u64[w], 0xAAAAAAAAAAAAAAAAULL);
+        }
+    }
+    m0 <<= 488;
+    ASSERT_TRUE(!m0.not_zero());
+})
+
+TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_left_shift_assignment, {
+    auto rng = SHARED_TEST_RNG();
+    for (int i = 0; i < 5; i++) {
+        std::uniform_int_distribution dist_bits(1, 1200);
+        int num_bits = dist_bits(rng);
+        simd_bits<W> m1(num_bits), m2(num_bits);
+        m1.randomize(num_bits, rng);
+        m2 = m1;
+        std::uniform_int_distribution dist_shift(0, (int)m1.num_bits_padded());
+        size_t shift = dist_shift(rng);
+        m1 <<= shift;
+        for (size_t k = 0; k < m1.num_bits_padded() - shift; k++) {
+            ASSERT_EQ(m1[k + shift], m2[k]);
+        }
+        for (size_t k = 0; k < shift; k++) {
+            ASSERT_EQ(m1[k], 0);
+        }
+    }
+})
+
 TEST_EACH_WORD_SIZE_W(simd_bits, assignment, {
     simd_bits<W> m0(512);
     simd_bits<W> m1(512);
diff --git a/src/stim/mem/simd_bits_range_ref.h b/src/stim/mem/simd_bits_range_ref.h
index 721b086b9..94ea73973 100644
--- a/src/stim/mem/simd_bits_range_ref.h
+++ b/src/stim/mem/simd_bits_range_ref.h
@@ -66,6 +66,11 @@ struct simd_bits_range_ref {
     /// Mask assignment.
     simd_bits_range_ref operator&=(const simd_bits_range_ref other);
     simd_bits_range_ref operator|=(const simd_bits_range_ref other);
+    // Addition assignment.
+    simd_bits_range_ref operator+=(const simd_bits_range_ref other);
+    // Shift assignment.
+    simd_bits_range_ref operator>>=(int offset);
+    simd_bits_range_ref operator<<=(int offset);
 
     /// Swap assignment.
     void swap_with(simd_bits_range_ref other);
diff --git a/src/stim/mem/simd_bits_range_ref.inl b/src/stim/mem/simd_bits_range_ref.inl
index d795dbbf2..4ca657975 100644
--- a/src/stim/mem/simd_bits_range_ref.inl
+++ b/src/stim/mem/simd_bits_range_ref.inl
@@ -54,6 +54,75 @@ simd_bits_range_ref<W> simd_bits_range_ref<W>::operator=(const simd_bits_range_r
     return *this;
 }
 
+template <size_t W>
+simd_bits_range_ref<W> simd_bits_range_ref<W>::operator+=(const simd_bits_range_ref<W> other) {
+    size_t num_u64 = num_u64_padded();
+    for (size_t w = 0; w < num_u64 - 1; w++) {
+        u64[w] += other.u64[w];
+        u64[w + 1] += (u64[w] < other.u64[w]);
+    }
+    u64[num_u64 - 1] += other.u64[num_u64 - 1];
+    return *this;
+}
+
+template <size_t W>
+simd_bits_range_ref<W> simd_bits_range_ref<W>::operator>>=(int offset) {
+    uint64_t incoming_word;
+    uint64_t cur_word;
+    if (offset == 0) {
+        return *this;
+    }
+    while (offset >= 64) {
+        incoming_word = 0ULL;
+        for (int w = num_u64_padded() - 1; w >= 0; w--) {
+            cur_word = u64[w];
+            u64[w] = incoming_word;
+            incoming_word = cur_word;
+        }
+        offset -= 64;
+    }
+    if (offset == 0) {
+        return *this;
+    }
+    incoming_word = 0ULL;
+    for (int w = num_u64_padded() - 1; w >= 0; w--) {
+        cur_word = u64[w];
+        u64[w] >>= offset;
+        u64[w] |= incoming_word << (64 - offset);
+        incoming_word = cur_word & ((uint64_t{1} << offset) - 1);
+    }
+    return *this;
+}
+
+template <size_t W>
+simd_bits_range_ref<W> simd_bits_range_ref<W>::operator<<=(int offset) {
+    uint64_t incoming_word;
+    uint64_t cur_word;
+    if (offset == 0) {
+        return *this;
+    }
+    while (offset >= 64) {
+        incoming_word = 0ULL;
+        for (int w = 0; w < num_u64_padded(); w++) {
+            cur_word = u64[w];
+            u64[w] = incoming_word;
+            incoming_word = cur_word;
+        }
+        offset -= 64;
+    }
+    if (offset == 0) {
+        return *this;
+    }
+    incoming_word = 0ULL;
+    for (int w = 0; w < num_u64_padded(); w++) {
+        cur_word = u64[w];
+        u64[w] <<= offset;
+        u64[w] |= incoming_word;
+        incoming_word = (cur_word >> (64 - offset));
+    }
+    return *this;
+}
+
 template <size_t W>
 void simd_bits_range_ref<W>::swap_with(simd_bits_range_ref<W> other) {
     for_each_word(other, [](bitword<W> &w0, bitword<W> &w1) {
@@ -153,4 +222,4 @@ bool simd_bits_range_ref<W>::intersects(const simd_bits_range_ref<W> other) cons
     return v != 0;
 }
 
-}
+}  // namespace stim
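For reference, a hedged usage sketch of the new operators (illustrative only,
not part of the patch; it assumes the W=64 instantiation that the tests in
this change also exercise):

    #include "stim/mem/simd_bits.h"

    using namespace stim;

    // simd_bits acts as a little-endian bit register: bit 0 lives in word 0,
    // += ripples carries upward, and <<=/>>= move bits across word boundaries.
    bool demo() {
        simd_bits<64> a(256);  // 256-bit register backed by 64-bit words.
        simd_bits<64> b(256);
        a[0] = 1;              // a = 1
        b[63] = 1;             // b = 2^63
        a += b;                // a = 2^63 + 1
        a <<= 1;               // a = 2^64 + 2; the carry crosses into word 1.
        a >>= 64;              // a = 1
        return a[0] && !a[1];  // true
    }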
diff --git a/src/stim/mem/simd_bits_range_ref.test.cc b/src/stim/mem/simd_bits_range_ref.test.cc
index 23e9b7153..f23f2ae99 100644
--- a/src/stim/mem/simd_bits_range_ref.test.cc
+++ b/src/stim/mem/simd_bits_range_ref.test.cc
@@ -162,6 +162,103 @@ TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, equality, {
     ASSERT_FALSE(m0 != m1);
 })
 
+TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, add_assignment, {
+    alignas(64) std::array<uint64_t, 8> data{
+        0xFFFFFFFFFFFFFFFFULL,
+        0x0F0F0F0F0F0F0F0FULL,
+        0xFFFFFFFFFFFFFFFFULL,
+        0x0F0F0F0F0F0F0F0FULL,
+        0xFFFFFFFFFFFFFFFFULL,
+        0x0F0F0F0F0F0F0F0FULL,
+        0xFFFFFFFFFFFFFFFFULL,
+        0x0F0F0F0F0F0F0F0FULL};
+    simd_bits_range_ref<W> m0((bitword<W> *)&data[0], sizeof(data) / sizeof(bitword<W>) / 2);
+    simd_bits_range_ref<W> m1((bitword<W> *)&data[4], sizeof(data) / sizeof(bitword<W>) / 2);
+    m0 += m1;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word % 2 == 0) {
+            ASSERT_EQ(pattern, 0xFFFFFFFFFFFFFFFEULL);
+        } else {
+            ASSERT_EQ(pattern, 0x1E1E1E1E1E1E1E1FULL);
+        }
+    }
+})
+
+TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, right_shift_assignment, {
+    alignas(64) std::array<uint64_t, 8> data{
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+    };
+    simd_bits_range_ref<W> m0((bitword<W> *)&data[0], sizeof(data) / sizeof(bitword<W>));
+    m0 >>= 1;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        ASSERT_EQ(pattern, 0x5555555555555555ULL);
+    }
+    m0 >>= 511;
+    ASSERT_TRUE(!m0.not_zero());
+})
+
+TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, left_shift_assignment, {
+    alignas(64) std::array<uint64_t, 8> data{
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+        0xAAAAAAAAAAAAAAAAULL,
+    };
+    simd_bits_range_ref<W> m0((bitword<W> *)&data[0], sizeof(data) / sizeof(bitword<W>));
+    m0 <<= 1;
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word == 0) {
+            ASSERT_EQ(pattern, 0x5555555555555554ULL);
+        } else {
+            ASSERT_EQ(pattern, 0x5555555555555555ULL);
+        }
+    }
+    m0 <<= 63;
+    for (size_t w = 0; w < m0.num_u64_padded(); w++) {
+        if (w == 0) {
+            ASSERT_EQ(m0.u64[w], 0ULL);
+        } else {
+            ASSERT_EQ(m0.u64[w], 0xAAAAAAAAAAAAAAAAULL);
+        }
+    }
+    for (size_t word = 0; word < m0.num_u64_padded(); word++) {
+        uint64_t pattern = 0ULL;
+        for (size_t k = 0; k < 64; k++) {
+            pattern |= (uint64_t{m0[word * 64 + k]} << k);
+        }
+        if (word == 0) {
+            ASSERT_EQ(pattern, 0ULL);
+        } else {
+            ASSERT_EQ(pattern, 0xAAAAAAAAAAAAAAAAULL);
+        }
+    }
+    m0 <<= 488;
+    ASSERT_TRUE(!m0.not_zero());
+})
+
 TEST_EACH_WORD_SIZE_W(simd_bits_range_ref, swap_with, {
     alignas(64) std::array<uint64_t, 16> data{};
     simd_bits_range_ref<W> m0((bitword<W> *)&data[0], sizeof(data) / sizeof(bitword<W>) / 4);