Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pauli iter all #598

Closed
wants to merge 11 commits into from
Closed

Pauli iter all #598

wants to merge 11 commits into from

Conversation

fdmalone
Copy link
Contributor

@fdmalone fdmalone commented Jul 21, 2023

Adds stim.PauliString.iter_all method.

Following the discussion in #397 the algorithm is a combination of finding the next lexicographically ordered permutation of w bits followed by iterating over all 3^w PauliStrings given this permutation of the qubit labels. For the first part I modified the bit twiddle algorithm to account for multiple words, for the second part I just iterate over 3^w integers and map this to a Pauli using the ternary representation of the integer. This was a bit trickier to get right than I expected, plenty of edge cases.

The modifications for the bit twiddle algorithm are quite specific and I was a bit torn between adding more general operators to simd_bits (like left / right shift + subtraction) which would be cleaner, and what I did, which is a bit clunky.

There are also a few optimizations that could be made which I haven't:

  1. for small w I could avoid the repeated loop over 3^w and store the Pauli strings in the first pass, rather than repeating this work for each of the num_qubits C w permutations (combinations))
  2. some bitwise algorithms are suboptimal (counting trailing zeros) and/or may already exist in stim but I couldn't find them. Happy to change these.
  3. The decoded bit labels / locations could be stored rather than decode each time.

Draft for the moment until:

  • Fix pybind issue.
  • Profile against naive python implmentation.
  • Check docstrings / algorithm description

Copy link
Collaborator

@Strilanc Strilanc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good so far. A few ideas commented.

doc/python_api_reference_vDev.md Outdated Show resolved Hide resolved
src/stim/py/numpy.pybind.h Outdated Show resolved Hide resolved
pybind11::arg("num_qubits"),
pybind11::kw_only(),
pybind11::arg("min_weight") = pybind11::none(),
pybind11::arg("max_weight") = pybind11::none(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A useful argument here would be allowed_paulis: str, so the user could restrict to X errors (allowed_paulis="X") or not-Y-errors (allowed_paulis="XZ").

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still need to add this, and also the random selection of signs/phases.

src/stim/stabilizers/pauli_string.pybind.cc Outdated Show resolved Hide resolved
src/stim/stabilizers/pauli_string_iter.inl Outdated Show resolved Hide resolved
src/stim/stabilizers/pauli_string_iter.inl Outdated Show resolved Hide resolved
src/stim/stabilizers/pauli_string_iter.perf.cc Outdated Show resolved Hide resolved
Comment on lines 7922 to 7923
) -> None:
"""Seed the iterator with a given qubit pattern.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not clear to me what this means.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually going to remove this function, after closer thought it's not super helpful. The idea was to be able to "seed" the iterator at a specific bit pattern which may be difficult to reach if w was large, but this can be tested on the C++ side more easily, and not sure if it's actually useful on the python side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually scratch that, I needed it for testing random starting points for long strings with higher weight. I will change the name to something more descriptive

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this to set_current_permutation.

@fdmalone fdmalone marked this pull request as ready for review July 21, 2023 23:00
@fdmalone
Copy link
Contributor Author

Sorry, I forced pushed when it was still a draft to clean up the many doc related fixes. Should be ready for review now sans profiling.

@Strilanc
Copy link
Collaborator

I was a bit torn between adding more general operators to simd_bits (like left / right shift + subtraction) which would be cleaner, and what I did, which is a bit clunky

I'd be fine with those additions to simd_bits.

@Strilanc
Copy link
Collaborator

Strilanc commented Jul 22, 2023

Here's bit twiddle python code that generates all the requested pauli strings of a given weight, assuming length<=32.

from typing import Iterable


def bits_to_paulis(x: int, n: int) -> str:
    bits = bin(x)[2:].rjust(n*2, '0')[-n*2:]
    return ''.join('_XYZ'[int(bits[k:k+2], 2)] for k in range(0, len(bits), 2))


def count_trailing_zeros(x: int) -> int:
    t = 0
    while not (x & 1):
        x >>= 1
        t += 1
    return t


def masked_increment(x: int, mask: int) -> int:
    return ((x | ~mask) + 1) & mask


def next_bitstring_of_same_hamming_weight(x: int) -> int:
    c1 = x | (x - 1)
    c2 = c1 + 1
    c3 = (~c1 & -~c1) - 1
    c4 = c3 >> (count_trailing_zeros(x) + 1)
    return c2 | c4


def iter_pauli_strings(weight: int, length: int) -> Iterable[str]:
    hamming_mask = (1 << weight) - 1
    while hamming_mask < 2**length:
        # Spread out bits into pairs.
        h = hamming_mask
        mh = 0b0000000000000000000000000000000011111111111111110000000000000000
        ml = 0b0000000000000000000000000000000000000000000000001111111111111111
        h = (h & ml) | ((h & mh) << 16)
        mh = 0b0000000000000000111111110000000000000000000000001111111100000000
        ml = 0b0000000000000000000000001111111100000000000000000000000011111111
        h = (h & ml) | ((h & mh) << 8)
        mh = 0b0000000011110000000000001111000000000000111100000000000011110000
        ml = 0b0000000000001111000000000000111100000000000011110000000000001111
        h = (h & ml) | ((h & mh) << 4)
        mh = 0b0000110000001100000011000000110000001100000011000000110000001100
        ml = 0b0000001100000011000000110000001100000011000000110000001100000011
        h = (h & ml) | ((h & mh) << 2)
        mh = 0b0010001000100010001000100010001000100010001000100010001000100010
        ml = 0b0001000100010001000100010001000100010001000100010001000100010001
        h = (h & ml) | ((h & mh) << 1)
        h |= h << 1

        # Iterate over non-00 values of masked bit pairs, with 00 elsewhere
        xz_mask = 0
        for _ in range(3**weight):
            xz_mask = masked_increment(xz_mask, h)
            xz_mask |= ~(xz_mask | (xz_mask >> 1)) & 0b0101010101010101010101010101010101010101010101010101010101010101
            xz_mask &= h
            yield bits_to_paulis(xz_mask, length)

        # Next mask
        hamming_mask = next_bitstring_of_same_hamming_weight(hamming_mask)


t = 0
for e in iter_pauli_strings(weight=3, length=20):
    print(e)
    t += 1
print(t)

@Strilanc
Copy link
Collaborator

A better one (in particular see pair_sat_increment):

import math
from typing import Iterable, Tuple


def bits_to_paulis(x: int, z: int, n: int) -> str:
    s = ''
    for k in range(n):
        s += '_XYZ'[(x & 1) + (z & 1) * 2]
        x >>= 1
        z >>= 1
    return s[::-1]


def count_trailing_zeros(x: int) -> int:
    t = 0
    while not (x & 1):
        x >>= 1
        t += 1
    return t


def masked_increment(x: int, mask: int) -> int:
    return ((x | ~mask) + 1) & mask


def next_bitstring_of_same_hamming_weight(x: int) -> int:
    c1 = x | (x - 1)
    c3 = ((c1 + 1) & ~c1) - 1
    c4 = c3 >> (count_trailing_zeros(x) + 1)
    return (c1 + 1) | c4


def pair_sat_increment(x: int, z: int, m: int) -> Tuple[int, int]:
    """Finds the next (x, z) such that x | z == m."""
    inc = x & z
    up = ~inc
    inc |= ~m
    inc += 1
    inc &= m
    up &= inc
    z &= inc | ~x
    z ^= x & up
    x ^= up
    return x, z


def iter_pauli_strings(weight: int, length: int) -> Iterable[str]:
    h = (1 << weight) - 1
    while h < 2**length:
        x, z = h, 0
        for _ in range(3**weight):
            yield bits_to_paulis(x, z, length)
            x, z = pair_sat_increment(x, z, h)
        h = next_bitstring_of_same_hamming_weight(h)


t = 0
seen = set()
for e in iter_pauli_strings(weight=4, length=10):
    assert e not in seen, e
    seen.add(e)
    print(e)
    t += 1
print(t, 3**4 * math.factorial(10) // math.factorial(6) // math.factorial(4))

@fdmalone
Copy link
Contributor Author

Ok, I can replace my ternary iteration with the bit twiddle above.

@fdmalone
Copy link
Contributor Author

fdmalone commented Jul 25, 2023

In particular, I will add additional ops to simd_bits (left/right shift and an adder) and replace my ternary iteration with the bit twiddle. I may separate out the new operations in a separate PR.

@Strilanc
Copy link
Collaborator

SGTM

Strilanc pushed a commit that referenced this pull request Aug 16, 2023
From #598 added +=, >>= and <<= to simd_bits. It wasn't obvious to me that these could use word level parallelism without using more memory? For example, the shifts could store the relevant carry masks and or these at the end but this would require a temporary of the same size as the simd_bits instance.
@fdmalone
Copy link
Contributor Author

fdmalone commented Sep 5, 2023

The windows failures seem to be due to a memory error

          pauli_it = list(
              stim.PauliString.iter_all(
  >               num_qubits, min_weight=min_weight, max_weight=max_weight
              )
          )
  E       MemoryError

The test_iter_all_random_permutation tests also seems to trigger test failures on windows periodically on win32 platforms. Not sure what's going on there.

Strilanc pushed a commit that referenced this pull request Sep 11, 2023
Caught this when trying to address #598.
Copy link
Collaborator

@Strilanc Strilanc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the major tasks left are debugging the windows crash and adding pytest unit tests using the python side of the API.

// x ^= up
result.xs ^= up;
cur_k++;
return true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a while loop if it always exits?


template <size_t W>
bool PauliStringIterator<W>::pair_sat_increment() {
// This will overflow for large cur_w.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overflow is bad or good here?

Copy link
Contributor Author

@fdmalone fdmalone Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad. I think I should assert on max_weight <= 40 on the pybind side? Or catch the wrapping and exit? Or just use the too large value since it would be pretty hard to iterate through 10^19 values.

template <size_t W>
bool PauliStringIterator<W>::pair_sat_increment() {
// This will overflow for large cur_w.
size_t num_terms = static_cast<size_t>(pow(3, cur_w));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this actually returns the right answer for all relevant values? Doubles have 53 bits of precision; not enough for a 64 bit integer. It'd be safer to have a method like

uint64_t pow3(uint64_t p) {
  assert(p < 41); 
  uint64_t r = 1;
  if (p & 1) r *= 3
  if (p & 2) r *= 9;
  if (p & 4) r *= 81;
  if (p & 8) r *= 6561;
  if (p & 16) r *= 43046721;
  if (p & 32) r *= 1853020188851841;
  return r;
}

simd_bits<W> one(cur_perm.num_bits_padded());
one.u64[0] = uint64_t{1};
// inc += 1
inc += one;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a ++x method to simd_bits?

Copy link
Contributor Author

@fdmalone fdmalone Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it probably makes sense to add -- / -= too. For example, there are a few awkward parts where I'm doing (1 << sum_number_greather_than_64) - 1 in a clunky way.

@fdmalone
Copy link
Contributor Author

I'm going to close this as I won't have time to come back to it for another month or so.

@fdmalone fdmalone closed this Oct 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants