diff --git a/src/stim/mem/simd_bits.test.cc b/src/stim/mem/simd_bits.test.cc index 62b9eb9c9..3080f2163 100644 --- a/src/stim/mem/simd_bits.test.cc +++ b/src/stim/mem/simd_bits.test.cc @@ -164,6 +164,15 @@ TEST_EACH_WORD_SIZE_W(simd_bits, xor_assignment, { } }) +template +void set_bits_from_u64_vector(simd_bits &bits, std::vector &vec) { + for (size_t w = 0; w < bits.num_u64_padded(); w++) { + for (size_t b = 0; b < 64; b++) { + bits[w * 64 + b] |= (vec[w] & (1ULL << b)); + } + } +} + TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, { simd_bits m0(512); simd_bits m1(512); @@ -228,6 +237,150 @@ TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, { m0 += m1; ASSERT_EQ(m0[0], 0); ASSERT_EQ(m0[64], 1); + // Test carrying across multiple (>=2) words. + size_t num_bits = 193; + simd_bits add(num_bits); + simd_bits one(num_bits); + for (size_t word = 0; word < add.num_u64_padded() - 1; word++) { + for (size_t k = 0; k < 64; k++) { + add[word * 64 + k] = 1; + } + } + one[0] = 1; + add += one; + // These should all overflow and carries should propagate to the last word. + for (size_t k = 0; k < num_bits - 1; k++) { + ASSERT_EQ(add[k], 0); + } + ASSERT_EQ(add[num_bits - 1], 1); + // From python + std::vector x{ + 0ULL, + 14988672980522980357ULL, + 18446744073709551615ULL, + 18446744073709551615ULL, + 6866573900576593249ULL, + 0ULL, + 18446744073709551615ULL, + 0ULL}; + std::vector y{ + 4413476325400229597ULL, + 0ULL, + 9428810821357656676ULL, + 7863636477302268070ULL, + 0ULL, + 18446744073709551615ULL, + 0ULL, + 15077824728923429555ULL}; + std::vector z{ + 4413476325400229597ULL, + 14988672980522980357ULL, + 9428810821357656675ULL, + 7863636477302268070ULL, + 6866573900576593250ULL, + 18446744073709551615ULL, + 18446744073709551615ULL, + 15077824728923429555ULL}; + simd_bits a(512), b(512), ref(512); + set_bits_from_u64_vector(a, x); + set_bits_from_u64_vector(b, y); + set_bits_from_u64_vector(ref, z); + a += b; + ASSERT_EQ(a, ref); +}) + +template +void set_random_words_to_all_set( + simd_bits &bits, size_t num_bits, std::mt19937_64 &rng, std::uniform_real_distribution &dist_real) { + bits.randomize(num_bits, rng); + size_t max_bit = W; + for (size_t iword = 0; iword < bits.num_simd_words; iword++) { + double r = dist_real(rng); + if (iword == bits.num_simd_words - 1) { + max_bit = num_bits - W * iword; + } + if (r < 1.0 / 3.0) { + double rall = dist_real(rng); + if (rall > 0.5) { + for (size_t k = 0; k < max_bit; k++) { + bits[iword * W + k] = 1; + } + } else { + for (size_t k = 0; k < max_bit; k++) { + bits[iword * W + k] = 0; + } + } + } + } +} + +TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_add_assignment, { + auto rng = INDEPENDENT_TEST_RNG(); + // a + b == b + a + std::uniform_real_distribution dist_real(0, 1); + for (int i = 0; i < 10; i++) { + std::uniform_int_distribution dist_bits(1, 1200); + int num_bits = dist_bits(rng); + simd_bits m1(num_bits), m2(num_bits); + set_random_words_to_all_set(m1, num_bits, rng, dist_real); + set_random_words_to_all_set(m2, num_bits, rng, dist_real); + simd_bits ref1(m1), ref2(m2); + m1 += ref2; + m2 += ref1; + ASSERT_EQ(m1, m2); + } + // (a + 1) + ~a = allset + for (int i = 0; i < 10; i++) { + std::uniform_int_distribution dist_bits(1, 1200); + int num_bits = dist_bits(rng); + simd_bits m1(num_bits); + simd_bits zero(num_bits); + simd_bits one(num_bits); + one[0] = 1; + set_random_words_to_all_set(m1, num_bits, rng, dist_real); + simd_bits m2(m1); + m2.invert_bits(); + m1 += one; + m1 += m2; + ASSERT_EQ(m1, zero); + } + // m1 += x; m1 = ~m1; m1 += x; m1 is unchanged. + for (int i = 0; i < 10; i++) { + std::uniform_int_distribution dist_bits(1, 1200); + int num_bits = dist_bits(rng); + simd_bits m1(num_bits); + m1.randomize(num_bits, rng); + set_random_words_to_all_set(m1, num_bits, rng, dist_real); + simd_bits ref(m1); + simd_bits m2(num_bits); + m1 += m2; + m1.invert_bits(); + m1 += m2; + m1.invert_bits(); + ASSERT_EQ(m1, ref); + } + // a + (b + c) == (a + b) + c + for (int i = 0; i < 10; i++) { + std::uniform_int_distribution dist_bits(1, 1200); + int num_bits = dist_bits(rng); + simd_bits alhs(num_bits); + simd_bits blhs(num_bits); + simd_bits clhs(num_bits); + simd_bits arhs(num_bits); + simd_bits brhs(num_bits); + simd_bits crhs(num_bits); + set_random_words_to_all_set(alhs, num_bits, rng, dist_real); + arhs = alhs; + set_random_words_to_all_set(blhs, num_bits, rng, dist_real); + brhs = blhs; + set_random_words_to_all_set(clhs, num_bits, rng, dist_real); + crhs = clhs; + blhs += clhs; + alhs += blhs; + arhs += brhs; + arhs += crhs; + ASSERT_EQ(alhs, arhs); + } }) TEST_EACH_WORD_SIZE_W(simd_bits, right_shift_assignment, { diff --git a/src/stim/mem/simd_bits_range_ref.inl b/src/stim/mem/simd_bits_range_ref.inl index 46fc514bd..0d6040331 100644 --- a/src/stim/mem/simd_bits_range_ref.inl +++ b/src/stim/mem/simd_bits_range_ref.inl @@ -57,11 +57,12 @@ simd_bits_range_ref simd_bits_range_ref::operator=(const simd_bits_range_r template simd_bits_range_ref simd_bits_range_ref::operator+=(const simd_bits_range_ref other) { size_t num_u64 = num_u64_padded(); - for (size_t w = 0; w < num_u64 - 1; w++) { - u64[w] += other.u64[w]; - u64[w + 1] += (u64[w] < other.u64[w]); + uint64_t carry{0}; + for (size_t w = 0; w < num_u64; w++) { + uint64_t val_before = u64[w]; + u64[w] += other.u64[w] + carry; + carry = u64[w] < val_before || (carry & (val_before == u64[w])); } - u64[num_u64 - 1] += other.u64[num_u64 - 1]; return *this; }