Skip to content

Commit

Permalink
Add WCNF generator to find minimum distance of stabilizer protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
noajshu committed Feb 27, 2024
1 parent f11b4f9 commit d5a67cc
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 0 deletions.
1 change: 1 addition & 0 deletions file_lists/source_files_no_main
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ src/stim/search/hyper/edge.cc
src/stim/search/hyper/graph.cc
src/stim/search/hyper/node.cc
src/stim/search/hyper/search_state.cc
src/stim/search/sat/wcnf.cc
src/stim/simulators/error_analyzer.cc
src/stim/simulators/error_matcher.cc
src/stim/simulators/force_streaming.cc
Expand Down
1 change: 1 addition & 0 deletions file_lists/test_files
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ src/stim/search/hyper/edge.test.cc
src/stim/search/hyper/graph.test.cc
src/stim/search/hyper/node.test.cc
src/stim/search/hyper/search_state.test.cc
src/stim/search/sat/wcnf.test.cc
src/stim/simulators/count_determined_measurements.test.cc
src/stim/simulators/dem_sampler.test.cc
src/stim/simulators/error_analyzer.test.cc
Expand Down
44 changes: 44 additions & 0 deletions src/stim/circuit/circuit.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ std::vector<ExplainedError> py_find_undetectable_logical_error(
return ErrorMatcher::explain_errors_from_circuit(self, &filter, reduce_to_representative);
}

std::string py_shortest_undetectable_logical_error_wcnf(const Circuit& self, size_t num_distinct_weights) {
DetectorErrorModel dem = ErrorAnalyzer::circuit_to_detector_error_model(self, false, true, false, 1, false, false);
return stim::shortest_undetectable_logical_error_wcnf(dem, num_distinct_weights);
}

void circuit_append(
Circuit &self,
const pybind11::object &obj,
Expand Down Expand Up @@ -1966,6 +1971,45 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_<Ci
5
)DOC")
.data());

c.def(
"shortest_undetectable_logical_error_wcnf",
&py_shortest_undetectable_logical_error_wcnf,
pybind11::kw_only(),
pybind11::arg("num_distinct_weights") = 1,
clean_doc_string(R"DOC(
Generates a maxSAT problem instance in WDIMACS format whose optimal value is
the distance of the protocol, i.e. the minimum weight of any set of errors
that forms an undetectable logical error.
Args:
num_distinct_weights: Defaults to 1 (unweighted). If > 1, the weights of
the errors will be quantized accordingly and the sum of the weights will
be minimized. For a reasonably large quantization (num_distinct_weights >
100), the .wcnf file solution should be the (approximately) most likely
undetectable logical error. Note, however, that maxSAT solvers often
become slower when many distinct weights are provided, so for computing
the distance it is better to use the default quantization
num_distinct_weights = 1.
Returns:
A WCNF file in [WDIMACS format](http://www.maxhs.org/docs/wdimacs.html)
Examples:
>>> import stim
>>> circuit = stim.Circuit.generated(
... "surface_code:rotated_memory_x",
... rounds=5,
... distance=5,
... after_clifford_depolarization=0.001)
>>> print(circuit.shortest_undetectable_logical_error_wcnf(
num_distinct_weights=1))
....
>>> print(circuit.shortest_undetectable_logical_error_wcnf(
num_distinct_weights=10))
....
)DOC")
.data());
c.def(
"explain_detector_error_model_errors",
[](const Circuit &self,
Expand Down
13 changes: 13 additions & 0 deletions src/stim/circuit/circuit_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,19 @@ def test_search_for_undetectable_logical_errors_msgs():
dont_explore_detection_event_sets_with_size_above=4,
)

def test_shortest_undetectable_logical_error_wcnf():
c = stim.Circuit("""
X_ERROR(0.1) 0
M 0
OBSERVABLE_INCLUDE(0) rec[-1]
X_ERROR(0.4) 0
M 0
DETECTOR rec[-1] rec[-2]
""")
wcnf_str = c.shortest_undetectable_logical_error_wcnf()
assert wcnf_str == 'p wcnf 2 3 4\n1 -1 0\n1 -2 0\n4 2 0\n'
wcnf_str = c.shortest_undetectable_logical_error_wcnf(num_distinct_weights=2)
assert wcnf_str == 'p wcnf 2 3 7\n1 -1 0\n2 -2 0\n7 2 0\n'

def test_shortest_graphlike_error_ignore():
c = stim.Circuit("""
Expand Down
220 changes: 220 additions & 0 deletions src/stim/search/sat/wcnf.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "stim/search/sat/wcnf.h"


using namespace stim;

typedef double Weight;
constexpr Weight HARD_CLAUSE_WEIGHT = -1.0;

constexpr size_t BOOL_LITERAL_FALSE = SIZE_MAX - 1;
constexpr size_t BOOL_LITERAL_TRUE = SIZE_MAX;
struct BoolRef {
size_t variable = BOOL_LITERAL_FALSE;
bool negated = false;
BoolRef operator~() const {
return {variable, !negated};
}
static BoolRef False() {
return {BOOL_LITERAL_FALSE, false};
}
static BoolRef True() {
return {BOOL_LITERAL_TRUE, false};
}
};

struct Clause {
std::vector<BoolRef> vars;
Weight weight = HARD_CLAUSE_WEIGHT;
void add_var(BoolRef x) {
vars.push_back(x);
}
};

struct MaxSATInstance {
size_t num_variables = 0;
Weight max_weight = 0;
std::vector<Clause> clauses;
BoolRef new_bool() {
return {num_variables++};
}
void add_clause(Clause& clause) {
if (clause.weight != HARD_CLAUSE_WEIGHT) {
if (clause.weight <= 0) {
throw std::invalid_argument("Clauses must have positive weight or HARD_CLAUSE_WEIGHT.");
}
max_weight = std::max(max_weight, clause.weight);
}
clauses.push_back(clause);
}
BoolRef Xor(BoolRef& x, BoolRef& y) {
if (x.variable == BOOL_LITERAL_FALSE) {
return y;
}
if (x.variable == BOOL_LITERAL_TRUE) {
return ~y;
}
if (y.variable == BOOL_LITERAL_FALSE) {
return x;
}
if (y.variable == BOOL_LITERAL_TRUE) {
return ~x;
}
BoolRef z = new_bool();
// Forbid strings (x, y, z) such that z != XOR(x, y)
{
Clause clause;
// hard clause (x, y, z) != (0, 0, 1)
clause.add_var(x);
clause.add_var(y);
clause.add_var(~z);
add_clause(clause);
}
{
Clause clause;
// hard clause (x, y, z) != (0, 1, 0)
clause.add_var(x);
clause.add_var(~y);
clause.add_var(z);
add_clause(clause);
}
{
Clause clause;
// hard clause (x, y, z) != (1, 0, 0)
clause.add_var(~x);
clause.add_var(y);
clause.add_var(z);
add_clause(clause);
}
{
Clause clause;
// hard clause (x, y, z) != (1, 1, 1)
clause.add_var(~x);
clause.add_var(~y);
clause.add_var(~z);
add_clause(clause);
}
return z;
}

size_t quantized_weight(size_t num_distinct_weights, size_t top, Weight weight) {
if (weight == HARD_CLAUSE_WEIGHT) {
return top;
}
return std::max(1ul, (size_t)std::round(weight / max_weight * (double)num_distinct_weights));
}

std::string to_wdimacs(size_t num_distinct_weights=1) {
if (num_distinct_weights < 1) {
throw std::invalid_argument("There must be at least 1 distinct weight value");
}
// 'top' is a special weight used to indicate a hard clause.
// Should be at least the sum of the weights of all soft clauses plus 1.
size_t top = 1 + num_distinct_weights * clauses.size();

// WDIMACS header format: p wcnf nbvar nbclauses top
// see http://www.maxhs.org/docs/wdimacs.html
std::stringstream ss;
ss << "p wcnf "<< num_variables << " "<<clauses.size() << " " << top << "\n";

// Add clauses, 1 on each line.
for (const auto& clause : clauses) {
// WDIMACS clause format: weight var1 var2 ...
// To show negation of a variable, the index should be negated.
ss << quantized_weight(num_distinct_weights, top, clause.weight);
for (size_t i=0; i<clause.vars.size(); ++i) {
BoolRef var = clause.vars[i];
// Variables are 1-indexed
if (var.negated) {
ss << " -" << (var.variable + 1);
} else {
ss << " " << (var.variable + 1);
}
}
// Each clause ends with 0
ss << " 0\n";
}
return ss.str();
}
};

std::string stim::shortest_undetectable_logical_error_wcnf(const DetectorErrorModel &model, size_t num_distinct_weights) {
MaxSATInstance inst;

size_t num_observables = model.count_observables();
size_t num_detectors = model.count_detectors();
size_t num_errors = model.count_errors();
if (num_observables == 0 or num_detectors == 0 or num_errors == 0) {
std::stringstream err_msg;
err_msg << "Failed to find any logical errors.";
if (num_observables == 0) {
err_msg << "\n WARNING: NO OBSERVABLES. The circuit or detector error model didn't define any observables, "
"making it vacuously impossible to find a logical error.";
}
if (num_detectors == 0) {
err_msg << "\n WARNING: NO DETECTORS. The circuit or detector error model didn't define any detectors.";
}
if (num_errors == 0) {
err_msg << "\n WARNING: NO ERRORS. The circuit or detector error model didn't include any errors, making it "
"vacuously impossible to find a logical error.";
}
throw std::invalid_argument(err_msg.str());
}
if (num_distinct_weights == 0) {
throw std::invalid_argument("num_distinct_weights must be >= 1");
}

MaxSATInstance instance;
// Create a boolean variable for each error, which indicates whether it is activated.
std::vector<BoolRef> errors_activated;
for (size_t i=0; i<num_errors; ++i) {
errors_activated.push_back(instance.new_bool());
}

std::vector<BoolRef> detectors_activated(num_detectors, BoolRef::False());
std::vector<BoolRef> observables_flipped(num_observables, BoolRef::False());

size_t error_index = 0;
model.iter_flatten_error_instructions([&](const DemInstruction &e) {
if (e.arg_data[0] != 0) {
BoolRef err_x = errors_activated[error_index];
// Add parity contribution to the detectors and observables
for (const auto &t : e.target_data) {
if (t.is_relative_detector_id()) {
detectors_activated[t.val()] = instance.Xor(detectors_activated[t.val()], err_x);
} else if (t.is_observable_id()) {
observables_flipped[t.val()] = instance.Xor(observables_flipped[t.val()], err_x);
}
}
// Add a soft clause for this error to be inactive
Clause clause;
clause.add_var(~err_x);
clause.weight = -std::log(e.arg_data[0] / (1 - e.arg_data[0]));
instance.add_clause(clause);
}
++error_index;
});

// Add a hard clause for any observable to be flipped
Clause clause;
for (size_t i=0; i<num_observables; ++i) {
clause.add_var(observables_flipped[i]);
}
instance.add_clause(clause);

return instance.to_wdimacs(num_distinct_weights);
}
35 changes: 35 additions & 0 deletions src/stim/search/sat/wcnf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef _STIM_SEARCH_SAT_WCNF_H
#define _STIM_SEARCH_SAT_WCNF_H

#include <cstdint>

#include "stim/dem/detector_error_model.h"

namespace stim {

/// Generates a maxSAT problem instance in .wcnf format from a DetectorErrorModel, such that the optimal value of the instance corresponds to the minimum distance of the protocol.
///
/// The .wcnf (weighted CNF) file format is widely
/// accepted by numerous maxSAT solvers. For example, the solvers in the 2023 maxSAT competition: https://maxsat-evaluations.github.io/2023/descriptions.html
/// Note that erformance can greatly vary among solvers.
/// The conversion involves encoding XOR constraints into CNF clauses using standard techniques.
///
/// Args:
/// model: The detector error model to be converted into .wcnf format for minimum distance calculation.
/// num_distinct_weights: The number of different integer weight values for quantization (default = 1).
///
/// Returns:
/// A string which is interpreted as the contents of a .wcnf file. This should be written to a file which can then be passed to various maxSAT
/// solvers to determine the minimum distance of the protocol represented by the model. The optimal value found
/// by the solver corresponds to the minimum distance of the error correction protocol. In other words, the smallest number of errors that cause a logical observable flip without any detection events.
///
/// Note:
/// The use of .wcnf format offers significant flexibility in choosing a maxSAT solver, but it also means that
/// users must separately manage the process of selecting and running the solver. This approach is designed to
/// sidestep the need for direct integration with any particular solver and allow
/// for experimentation with different solvers to achieve the best performance.
std::string shortest_undetectable_logical_error_wcnf(const DetectorErrorModel &model, size_t num_distinct_weights=1);

} // namespace stim

#endif
Loading

0 comments on commit d5a67cc

Please sign in to comment.