Skip to content

Commit

Permalink
Speed up stim.Tableau.from_stabilizers another 10x
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Mar 15, 2024
1 parent 4f1d217 commit 533e709
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 27 deletions.
120 changes: 95 additions & 25 deletions src/stim/stabilizers/conversions.inl
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,18 @@ Tableau<W> stabilizers_to_tableau(
num_qubits = std::max(num_qubits, e.num_qubits);
}

simd_bit_table<W> buf_xs(num_qubits, stabilizers.size());
simd_bit_table<W> buf_zs(num_qubits, stabilizers.size());
simd_bits<W> buf_signs(stabilizers.size());
simd_bits<W> buf_workspace(stabilizers.size());
for (size_t k = 0; k < stabilizers.size(); k++) {
memcpy(buf_xs[k].u8, stabilizers[k].xs.u8, stabilizers[k].xs.num_u8_padded());
memcpy(buf_zs[k].u8, stabilizers[k].zs.u8, stabilizers[k].zs.num_u8_padded());
buf_signs[k] = stabilizers[k].sign;
}
buf_xs = buf_xs.transposed();
buf_zs = buf_zs.transposed();

for (size_t k1 = 0; k1 < stabilizers.size(); k1++) {
for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) {
if (!stabilizers[k1].ref().commutes(stabilizers[k2])) {
Expand All @@ -572,38 +584,28 @@ Tableau<W> stabilizers_to_tableau(
}
}
Circuit elimination_instructions;
PauliString<W> buf(num_qubits);

size_t used = 0;
for (const auto &e : stabilizers) {
if (e.num_qubits == num_qubits) {
buf = e;
} else {
buf.xs.clear();
buf.zs.clear();
memcpy(buf.xs.u8, e.xs.u8, e.xs.num_u8_padded());
memcpy(buf.zs.u8, e.zs.u8, e.zs.num_u8_padded());
buf.sign = e.sign;
}
buf.ref().do_circuit(elimination_instructions);

for (size_t k = 0; k < stabilizers.size(); k++) {
// Find a non-identity term in the Pauli string past the region used by other stabilizers.
size_t pivot;
for (pivot = used; pivot < num_qubits; pivot++) {
if (buf.xs[pivot] || buf.zs[pivot]) {
if (buf_xs[pivot][k] || buf_zs[pivot][k]) {
break;
}
}

// Check for incompatible / redundant stabilizers.
if (pivot == num_qubits) {
if (buf.xs.not_zero()) {
throw std::invalid_argument("Some of the given stabilizers anticommute.");
for (size_t q = 0; q < num_qubits; q++) {
if (buf_xs[q][k]) {
throw std::invalid_argument("Some of the given stabilizers anticommute.");
}
}
if (buf.sign) {
if (buf_signs[k]) {
throw std::invalid_argument("Some of the given stabilizers contradict each other.");
}
if (!allow_redundant && buf.zs.not_zero()) {
if (!allow_redundant) {
throw std::invalid_argument(
"Didn't specify allow_redundant=True but one of the given stabilizers is a product of the others. "
"To allow redundant stabilizers, pass the argument allow_redundant=True.");
Expand All @@ -612,21 +614,86 @@ Tableau<W> stabilizers_to_tableau(
}

// Change pivot basis to the Z axis.
if (buf.xs[pivot]) {
GateType g = buf.zs[pivot] ? GateType::H_YZ : GateType::H;
if (buf_xs[pivot][k]) {
GateType g = buf_zs[pivot][k] ? GateType::H_YZ : GateType::H;
GateTarget t = GateTarget::qubit(pivot);
CircuitInstruction instruction{g, {}, &t};
elimination_instructions.safe_append(instruction);
buf.ref().do_instruction(instruction);
size_t q = pivot;
switch (g) {
case GateType::H_YZ:
buf_xs[q] ^= buf_zs[q];
buf_workspace = buf_zs[q];
buf_workspace.invert_bits();
buf_workspace &= buf_xs[q];
buf_signs ^= buf_workspace;
break;
case GateType::H:
buf_xs[q].swap_with(buf_zs[q]);
buf_workspace = buf_zs[q];
buf_workspace &= buf_xs[q];
buf_signs ^= buf_workspace;
break;
default:
throw std::invalid_argument("Unrecognized gate type.");
}
}

// Cancel other terms in Pauli string.
for (size_t q = 0; q < num_qubits; q++) {
int p = buf.xs[q] + buf.zs[q] * 2;
int p = buf_xs[q][k] + buf_zs[q][k] * 2;
if (p && q != pivot) {
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(q)};
CircuitInstruction instruction{p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY, {}, targets};
GateType g = p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY;
CircuitInstruction instruction{g, {}, targets};
elimination_instructions.safe_append(instruction);
buf.ref().do_instruction(instruction);
size_t q1 = targets[0].qubit_value();
size_t q2 = targets[1].qubit_value();
simd_bits_range_ref<W> x1 = buf_xs[q1];
simd_bits_range_ref<W> z1 = buf_zs[q1];
simd_bits_range_ref<W> x2 = buf_xs[q2];
simd_bits_range_ref<W> z2 = buf_zs[q2];
switch (g) {
case GateType::XCX:
buf_workspace = x1;
buf_workspace ^= x2;
buf_workspace &= z1;
buf_workspace &= z2;
buf_signs ^= buf_workspace;
x1 ^= z2;
x2 ^= z1;
break;
case GateType::XCY:
x1 ^= x2;
x1 ^= z2;
x2 ^= z1;
z2 ^= z1;
buf_workspace = x1;
buf_workspace |= x2;
buf_workspace.invert_bits();
buf_workspace &= z1;
buf_workspace &= z2;
buf_signs ^= buf_workspace;
buf_workspace = z2;
buf_workspace.invert_bits();
buf_workspace &= z1;
buf_workspace &= x1;
buf_workspace &= x2;
buf_signs ^= buf_workspace;
break;
case GateType::XCZ:
z2 ^= z1;
x1 ^= x2;
buf_workspace = z2;
buf_workspace ^= x1;
buf_workspace.invert_bits();
buf_workspace &= x2;
buf_workspace &= z1;
buf_signs ^= buf_workspace;
break;
default:
throw std::invalid_argument("Unrecognized gate type.");
}
}
}

Expand All @@ -635,13 +702,16 @@ Tableau<W> stabilizers_to_tableau(
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(used)};
CircuitInstruction instruction{GateType::SWAP, {}, targets};
elimination_instructions.safe_append(instruction);
buf_xs[pivot].swap_with(buf_xs[used]);
buf_zs[pivot].swap_with(buf_zs[used]);
}

// Fix sign.
if (buf.sign) {
if (buf_signs[k]) {
GateTarget t = GateTarget::qubit(used);
CircuitInstruction instruction{GateType::X, {}, &t};
elimination_instructions.safe_append(instruction);
buf_signs ^= buf_zs[used];
}

used++;
Expand Down
53 changes: 51 additions & 2 deletions src/stim/stabilizers/conversions.perf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ BENCHMARK(independent_to_disjoint_xyz_errors) {
}
}

BENCHMARK(stabilizers_to_tableau) {
BENCHMARK(stabilizers_to_tableau_144) {
std::vector<std::complex<float>> offsets{
{1, 0},
{-1, 0},
Expand Down Expand Up @@ -110,7 +110,56 @@ BENCHMARK(stabilizers_to_tableau) {
benchmark_go([&]() {
Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false);
dep += t.xs[0].zs[0];
}).goal_millis(5);
}).goal_micros(500);
if (dep == 99999999) {
std::cout << "data dependence";
}
}


BENCHMARK(stabilizers_to_tableau_576) {
std::vector<std::complex<float>> offsets{
{1, 0},
{-1, 0},
{0, 1},
{0, -1},
{3, 6},
{-6, 3},
};
size_t w = 24*4;
size_t h = 12*4;

auto normalize = [&](std::complex<float> c) -> std::complex<float> {
return {fmodf(c.real() + w*10, w), fmodf(c.imag() + h*10, h)};
};
auto q2i = [&](std::complex<float> c) -> size_t {
c = normalize(c);
return (int)c.real() / 2 + c.imag() * (w / 2);
};

std::vector<stim::PauliString<64>> stabilizers;
for (size_t x = 0; x < w; x++) {
for (size_t y = x % 2; y < h; y += 2) {
std::complex<float> s{x % 2 ? -1.0f : +1.0f, 0.0f};
std::complex<float> c{(float)x, (float)y};
stim::PauliString<64> ps(w * h / 2);
for (const auto &offset : offsets) {
size_t i = q2i(c + offset * s);
if (x % 2 == 0) {
ps.xs[i] = 1;
} else {
ps.zs[i] = 1;
}
}
stabilizers.push_back(ps);
}
}

size_t dep = 0;
benchmark_go([&]() {
Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false);
dep += t.xs[0].zs[0];
}).goal_millis(200);
if (dep == 99999999) {
std::cout << "data dependence";
}
Expand Down

0 comments on commit 533e709

Please sign in to comment.