Skip to content

Commit

Permalink
Fix bug with simd_bits +=. (#633)
Browse files Browse the repository at this point in the history
Caught this when trying to address #598.
  • Loading branch information
fdmalone authored Sep 11, 2023
1 parent c135a61 commit 0fdddef
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 4 deletions.
153 changes: 153 additions & 0 deletions src/stim/mem/simd_bits.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ TEST_EACH_WORD_SIZE_W(simd_bits, xor_assignment, {
}
})

template <size_t W>
void set_bits_from_u64_vector(simd_bits<W> &bits, std::vector<uint64_t> &vec) {
for (size_t w = 0; w < bits.num_u64_padded(); w++) {
for (size_t b = 0; b < 64; b++) {
bits[w * 64 + b] |= (vec[w] & (1ULL << b));
}
}
}

TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, {
simd_bits<W> m0(512);
simd_bits<W> m1(512);
Expand Down Expand Up @@ -228,6 +237,150 @@ TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, {
m0 += m1;
ASSERT_EQ(m0[0], 0);
ASSERT_EQ(m0[64], 1);
// Test carrying across multiple (>=2) words.
size_t num_bits = 193;
simd_bits<W> add(num_bits);
simd_bits<W> one(num_bits);
for (size_t word = 0; word < add.num_u64_padded() - 1; word++) {
for (size_t k = 0; k < 64; k++) {
add[word * 64 + k] = 1;
}
}
one[0] = 1;
add += one;
// These should all overflow and carries should propagate to the last word.
for (size_t k = 0; k < num_bits - 1; k++) {
ASSERT_EQ(add[k], 0);
}
ASSERT_EQ(add[num_bits - 1], 1);
// From python
std::vector<uint64_t> x{
0ULL,
14988672980522980357ULL,
18446744073709551615ULL,
18446744073709551615ULL,
6866573900576593249ULL,
0ULL,
18446744073709551615ULL,
0ULL};
std::vector<uint64_t> y{
4413476325400229597ULL,
0ULL,
9428810821357656676ULL,
7863636477302268070ULL,
0ULL,
18446744073709551615ULL,
0ULL,
15077824728923429555ULL};
std::vector<uint64_t> z{
4413476325400229597ULL,
14988672980522980357ULL,
9428810821357656675ULL,
7863636477302268070ULL,
6866573900576593250ULL,
18446744073709551615ULL,
18446744073709551615ULL,
15077824728923429555ULL};
simd_bits<W> a(512), b(512), ref(512);
set_bits_from_u64_vector(a, x);
set_bits_from_u64_vector(b, y);
set_bits_from_u64_vector(ref, z);
a += b;
ASSERT_EQ(a, ref);
})

template <size_t W>
void set_random_words_to_all_set(
simd_bits<W> &bits, size_t num_bits, std::mt19937_64 &rng, std::uniform_real_distribution<double> &dist_real) {
bits.randomize(num_bits, rng);
size_t max_bit = W;
for (size_t iword = 0; iword < bits.num_simd_words; iword++) {
double r = dist_real(rng);
if (iword == bits.num_simd_words - 1) {
max_bit = num_bits - W * iword;
}
if (r < 1.0 / 3.0) {
double rall = dist_real(rng);
if (rall > 0.5) {
for (size_t k = 0; k < max_bit; k++) {
bits[iword * W + k] = 1;
}
} else {
for (size_t k = 0; k < max_bit; k++) {
bits[iword * W + k] = 0;
}
}
}
}
}

TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_add_assignment, {
auto rng = INDEPENDENT_TEST_RNG();
// a + b == b + a
std::uniform_real_distribution<double> dist_real(0, 1);
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> m1(num_bits), m2(num_bits);
set_random_words_to_all_set(m1, num_bits, rng, dist_real);
set_random_words_to_all_set(m2, num_bits, rng, dist_real);
simd_bits<W> ref1(m1), ref2(m2);
m1 += ref2;
m2 += ref1;
ASSERT_EQ(m1, m2);
}
// (a + 1) + ~a = allset
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> m1(num_bits);
simd_bits<W> zero(num_bits);
simd_bits<W> one(num_bits);
one[0] = 1;
set_random_words_to_all_set(m1, num_bits, rng, dist_real);
simd_bits<W> m2(m1);
m2.invert_bits();
m1 += one;
m1 += m2;
ASSERT_EQ(m1, zero);
}
// m1 += x; m1 = ~m1; m1 += x; m1 is unchanged.
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> m1(num_bits);
m1.randomize(num_bits, rng);
set_random_words_to_all_set(m1, num_bits, rng, dist_real);
simd_bits<W> ref(m1);
simd_bits<W> m2(num_bits);
m1 += m2;
m1.invert_bits();
m1 += m2;
m1.invert_bits();
ASSERT_EQ(m1, ref);
}
// a + (b + c) == (a + b) + c
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> alhs(num_bits);
simd_bits<W> blhs(num_bits);
simd_bits<W> clhs(num_bits);
simd_bits<W> arhs(num_bits);
simd_bits<W> brhs(num_bits);
simd_bits<W> crhs(num_bits);
set_random_words_to_all_set(alhs, num_bits, rng, dist_real);
arhs = alhs;
set_random_words_to_all_set(blhs, num_bits, rng, dist_real);
brhs = blhs;
set_random_words_to_all_set(clhs, num_bits, rng, dist_real);
crhs = clhs;
blhs += clhs;
alhs += blhs;
arhs += brhs;
arhs += crhs;
ASSERT_EQ(alhs, arhs);
}
})

TEST_EACH_WORD_SIZE_W(simd_bits, right_shift_assignment, {
Expand Down
9 changes: 5 additions & 4 deletions src/stim/mem/simd_bits_range_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ simd_bits_range_ref<W> simd_bits_range_ref<W>::operator=(const simd_bits_range_r
template <size_t W>
simd_bits_range_ref<W> simd_bits_range_ref<W>::operator+=(const simd_bits_range_ref<W> other) {
size_t num_u64 = num_u64_padded();
for (size_t w = 0; w < num_u64 - 1; w++) {
u64[w] += other.u64[w];
u64[w + 1] += (u64[w] < other.u64[w]);
uint64_t carry{0};
for (size_t w = 0; w < num_u64; w++) {
uint64_t val_before = u64[w];
u64[w] += other.u64[w] + carry;
carry = u64[w] < val_before || (carry & (val_before == u64[w]));
}
u64[num_u64 - 1] += other.u64[num_u64 - 1];
return *this;
}

Expand Down

0 comments on commit 0fdddef

Please sign in to comment.