Skip to content

Commit

Permalink
Fix addition and add more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
fdmalone committed Sep 8, 2023
1 parent 8f06a5a commit 4e193f0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 13 deletions.
95 changes: 83 additions & 12 deletions src/stim/mem/simd_bits.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ TEST_EACH_WORD_SIZE_W(simd_bits, xor_assignment, {
}
})

template <size_t W>
void set_bits_from_u64_vector(simd_bits<W> &bits, std::vector<uint64_t> &vec) {
for (size_t w = 0; w < bits.num_u64_padded(); w++) {
for (size_t b = 0; b < 64; b++) {
bits[w * 64 + b] |= (vec[w] & (1ULL << b));
}
}
}

TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, {
simd_bits<W> m0(512);
simd_bits<W> m1(512);
Expand Down Expand Up @@ -244,41 +253,103 @@ TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, {
ASSERT_EQ(add[k], 0);
}
ASSERT_EQ(add[num_bits - 1], 1);
// From python
std::vector<uint64_t> x{
0ULL,
14988672980522980357ULL,
18446744073709551615ULL,
18446744073709551615ULL,
6866573900576593249ULL,
0ULL,
18446744073709551615ULL,
0ULL};
std::vector<uint64_t> y{
4413476325400229597ULL,
0ULL,
9428810821357656676ULL,
7863636477302268070ULL,
0ULL,
18446744073709551615ULL,
0ULL,
15077824728923429555ULL};
std::vector<uint64_t> z{
4413476325400229597ULL,
14988672980522980357ULL,
9428810821357656675ULL,
7863636477302268070ULL,
6866573900576593250ULL,
18446744073709551615ULL,
18446744073709551615ULL,
15077824728923429555ULL};
simd_bits<W> a(512), b(512), ref(512);
set_bits_from_u64_vector(a, x);
set_bits_from_u64_vector(b, y);
set_bits_from_u64_vector(ref, z);
// a += b;
a += b;
ASSERT_EQ(a, ref);
})

template <size_t W>
void set_random_words_to_all_set(
simd_bits<W> &bits, size_t num_bits, std::mt19937_64 &rng, std::uniform_real_distribution<double> &dist_real) {
bits.randomize(num_bits, rng);
size_t max_bit = 64;
for (size_t iword = 0; iword < bits.num_u64_padded(); iword++) {
double r = dist_real(rng);
if (iword == bits.num_u64_padded() - 1) {
max_bit = num_bits - 64 * iword;
}
if (r < 1.0 / 3.0) {
double rall = dist_real(rng);
if (rall > 0.5) {
for (size_t k = 0; k < max_bit; k++) {
bits[iword * 64 + k] = 1;
}
} else {
for (size_t k = 0; k < max_bit; k++) {
bits[iword * 64 + k] = 0;
}
}
}
}
}

TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_add_assignment, {
auto rng = INDEPENDENT_TEST_RNG();
// a + b == b + a
for (int i = 0; i < 10000; i++) {
std::uniform_real_distribution<double> dist_real(0, 1);
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> m1(num_bits), m2(num_bits);
m1.randomize(num_bits, rng);
m2.randomize(num_bits, rng);
set_random_words_to_all_set(m1, num_bits, rng, dist_real);
set_random_words_to_all_set(m2, num_bits, rng, dist_real);
simd_bits<W> ref1(m1), ref2(m2);
m1 += ref2;
m2 += ref1;
ASSERT_EQ(m1, m2);
}
// a + ~a = allset
for (int i = 0; i < 10000; i++) {
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> m1(num_bits);
simd_bits<W> allset(num_bits);
allset.invert_bits();
m1.randomize(num_bits, rng);
set_random_words_to_all_set(m1, num_bits, rng, dist_real);
simd_bits<W> m2(m1);
m2.invert_bits();
m1 += m2;
ASSERT_EQ(m1, allset);
}
// m1 += x; m1 = ~x; m1 += x; m1 is unchanged.
for (int i = 0; i < 10000; i++) {
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> m1(num_bits);
m1.randomize(num_bits, rng);
set_random_words_to_all_set(m1, num_bits, rng, dist_real);
simd_bits<W> ref(m1);
simd_bits<W> m2(num_bits);
m1 += m2;
Expand All @@ -287,21 +358,21 @@ TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_add_assignment, {
m1.invert_bits();
ASSERT_EQ(m1, ref);
}
// a + (b + c) == (a + b) + c
for (int i = 0; i < 10000; i++) {
// // a + (b + c) == (a + b) + c
for (int i = 0; i < 10; i++) {
std::uniform_int_distribution dist_bits(1, 1200);
int num_bits = dist_bits(rng);
simd_bits<W> arhs(num_bits);
simd_bits<W> alhs(num_bits);
simd_bits<W> blhs(num_bits);
simd_bits<W> clhs(num_bits);
simd_bits<W> arhs(num_bits);
simd_bits<W> brhs(num_bits);
simd_bits<W> crhs(num_bits);
alhs.randomize(num_bits, rng);
set_random_words_to_all_set(alhs, num_bits, rng, dist_real);
arhs = alhs;
blhs.randomize(num_bits, rng);
set_random_words_to_all_set(blhs, num_bits, rng, dist_real);
brhs = blhs;
clhs.randomize(num_bits, rng);
set_random_words_to_all_set(clhs, num_bits, rng, dist_real);
crhs = clhs;
blhs += clhs;
alhs += blhs;
Expand Down
2 changes: 1 addition & 1 deletion src/stim/mem/simd_bits_range_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ simd_bits_range_ref<W> simd_bits_range_ref<W>::operator+=(const simd_bits_range_
for (size_t w = 0; w < num_u64; w++) {
uint64_t val_before = u64[w];
u64[w] += other.u64[w] + carry;
carry = u64[w] < val_before;
carry = u64[w] < val_before || (carry & (val_before == u64[w]));
}
return *this;
}
Expand Down

0 comments on commit 4e193f0

Please sign in to comment.