From d5a67ccae4400e07c75a56981ea0381cd847af70 Mon Sep 17 00:00:00 2001 From: noajshu Date: Tue, 27 Feb 2024 23:26:44 +0000 Subject: [PATCH] Add WCNF generator to find minimum distance of stabilizer protocols --- file_lists/source_files_no_main | 1 + file_lists/test_files | 1 + src/stim/circuit/circuit.pybind.cc | 44 +++++ src/stim/circuit/circuit_pybind_test.py | 13 ++ src/stim/search/sat/wcnf.cc | 220 ++++++++++++++++++++++++ src/stim/search/sat/wcnf.h | 35 ++++ src/stim/search/sat/wcnf.test.cc | 49 ++++++ src/stim/search/search.h | 1 + 8 files changed, 364 insertions(+) create mode 100644 src/stim/search/sat/wcnf.cc create mode 100644 src/stim/search/sat/wcnf.h create mode 100644 src/stim/search/sat/wcnf.test.cc diff --git a/file_lists/source_files_no_main b/file_lists/source_files_no_main index b9694b6e5..0f96375e5 100644 --- a/file_lists/source_files_no_main +++ b/file_lists/source_files_no_main @@ -77,6 +77,7 @@ src/stim/search/hyper/edge.cc src/stim/search/hyper/graph.cc src/stim/search/hyper/node.cc src/stim/search/hyper/search_state.cc +src/stim/search/sat/wcnf.cc src/stim/simulators/error_analyzer.cc src/stim/simulators/error_matcher.cc src/stim/simulators/force_streaming.cc diff --git a/file_lists/test_files b/file_lists/test_files index b08819aff..a2e0ddbe7 100644 --- a/file_lists/test_files +++ b/file_lists/test_files @@ -56,6 +56,7 @@ src/stim/search/hyper/edge.test.cc src/stim/search/hyper/graph.test.cc src/stim/search/hyper/node.test.cc src/stim/search/hyper/search_state.test.cc +src/stim/search/sat/wcnf.test.cc src/stim/simulators/count_determined_measurements.test.cc src/stim/simulators/dem_sampler.test.cc src/stim/simulators/error_analyzer.test.cc diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index d0e0c6349..ba29d426e 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -80,6 +80,11 @@ std::vector py_find_undetectable_logical_error( return 
ErrorMatcher::explain_errors_from_circuit(self, &filter, reduce_to_representative); } +std::string py_shortest_undetectable_logical_error_wcnf(const Circuit& self, size_t num_distinct_weights) { + DetectorErrorModel dem = ErrorAnalyzer::circuit_to_detector_error_model(self, false, true, false, 1, false, false); + return stim::shortest_undetectable_logical_error_wcnf(dem, num_distinct_weights); +} + void circuit_append( Circuit &self, const pybind11::object &obj, @@ -1966,6 +1971,45 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ 1, the weights of + the errors will be quantized accordingly and the sum of the weights will + be minimized. For a reasonably large quantization (num_distinct_weights > + 100), the .wcnf file solution should be the (approximately) most likely + undetectable logical error. Note, however, that maxSAT solvers often + become slower when many distinct weights are provided, so for computing + the distance it is better to use the default quantization + num_distinct_weights = 1. + + Returns: + A WCNF file in [WDIMACS format](http://www.maxhs.org/docs/wdimacs.html) + + Examples: + >>> import stim + >>> circuit = stim.Circuit.generated( + ... "surface_code:rotated_memory_x", + ... rounds=5, + ... distance=5, + ... after_clifford_depolarization=0.001) + >>> print(circuit.shortest_undetectable_logical_error_wcnf( + num_distinct_weights=1)) + .... + >>> print(circuit.shortest_undetectable_logical_error_wcnf( + num_distinct_weights=10)) + .... 
+ )DOC") + .data()); c.def( "explain_detector_error_model_errors", [](const Circuit &self, diff --git a/src/stim/circuit/circuit_pybind_test.py b/src/stim/circuit/circuit_pybind_test.py index 4dca08503..f71150338 100644 --- a/src/stim/circuit/circuit_pybind_test.py +++ b/src/stim/circuit/circuit_pybind_test.py @@ -827,6 +827,19 @@ def test_search_for_undetectable_logical_errors_msgs(): dont_explore_detection_event_sets_with_size_above=4, ) +def test_shortest_undetectable_logical_error_wcnf(): + c = stim.Circuit(""" + X_ERROR(0.1) 0 + M 0 + OBSERVABLE_INCLUDE(0) rec[-1] + X_ERROR(0.4) 0 + M 0 + DETECTOR rec[-1] rec[-2] + """) + wcnf_str = c.shortest_undetectable_logical_error_wcnf() + assert wcnf_str == 'p wcnf 2 3 4\n1 -1 0\n1 -2 0\n4 2 0\n' + wcnf_str = c.shortest_undetectable_logical_error_wcnf(num_distinct_weights=2) + assert wcnf_str == 'p wcnf 2 3 7\n1 -1 0\n2 -2 0\n7 2 0\n' def test_shortest_graphlike_error_ignore(): c = stim.Circuit(""" diff --git a/src/stim/search/sat/wcnf.cc b/src/stim/search/sat/wcnf.cc new file mode 100644 index 000000000..c4a3078ca --- /dev/null +++ b/src/stim/search/sat/wcnf.cc @@ -0,0 +1,220 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "stim/search/sat/wcnf.h" + + +using namespace stim; + +typedef double Weight; +constexpr Weight HARD_CLAUSE_WEIGHT = -1.0; + +constexpr size_t BOOL_LITERAL_FALSE = SIZE_MAX - 1; +constexpr size_t BOOL_LITERAL_TRUE = SIZE_MAX; +struct BoolRef { + size_t variable = BOOL_LITERAL_FALSE; + bool negated = false; + BoolRef operator~() const { + return {variable, !negated}; + } + static BoolRef False() { + return {BOOL_LITERAL_FALSE, false}; + } + static BoolRef True() { + return {BOOL_LITERAL_TRUE, false}; + } +}; + +struct Clause { + std::vector vars; + Weight weight = HARD_CLAUSE_WEIGHT; + void add_var(BoolRef x) { + vars.push_back(x); + } +}; + +struct MaxSATInstance { + size_t num_variables = 0; + Weight max_weight = 0; + std::vector clauses; + BoolRef new_bool() { + return {num_variables++}; + } + void add_clause(Clause& clause) { + if (clause.weight != HARD_CLAUSE_WEIGHT) { + if (clause.weight <= 0) { + throw std::invalid_argument("Clauses must have positive weight or HARD_CLAUSE_WEIGHT."); + } + max_weight = std::max(max_weight, clause.weight); + } + clauses.push_back(clause); + } + BoolRef Xor(BoolRef& x, BoolRef& y) { + if (x.variable == BOOL_LITERAL_FALSE) { + return y; + } + if (x.variable == BOOL_LITERAL_TRUE) { + return ~y; + } + if (y.variable == BOOL_LITERAL_FALSE) { + return x; + } + if (y.variable == BOOL_LITERAL_TRUE) { + return ~x; + } + BoolRef z = new_bool(); + // Forbid strings (x, y, z) such that z != XOR(x, y) + { + Clause clause; + // hard clause (x, y, z) != (0, 0, 1) + clause.add_var(x); + clause.add_var(y); + clause.add_var(~z); + add_clause(clause); + } + { + Clause clause; + // hard clause (x, y, z) != (0, 1, 0) + clause.add_var(x); + clause.add_var(~y); + clause.add_var(z); + add_clause(clause); + } + { + Clause clause; + // hard clause (x, y, z) != (1, 0, 0) + clause.add_var(~x); + clause.add_var(y); + clause.add_var(z); + add_clause(clause); + } + { + Clause clause; + // hard clause (x, y, z) != (1, 1, 1) + 
clause.add_var(~x); + clause.add_var(~y); + clause.add_var(~z); + add_clause(clause); + } + return z; + } + + size_t quantized_weight(size_t num_distinct_weights, size_t top, Weight weight) { + if (weight == HARD_CLAUSE_WEIGHT) { + return top; + } + return std::max(1ul, (size_t)std::round(weight / max_weight * (double)num_distinct_weights)); + } + + std::string to_wdimacs(size_t num_distinct_weights=1) { + if (num_distinct_weights < 1) { + throw std::invalid_argument("There must be at least 1 distinct weight value"); + } + // 'top' is a special weight used to indicate a hard clause. + // Should be at least the sum of the weights of all soft clauses plus 1. + size_t top = 1 + num_distinct_weights * clauses.size(); + + // WDIMACS header format: p wcnf nbvar nbclauses top + // see http://www.maxhs.org/docs/wdimacs.html + std::stringstream ss; + ss << "p wcnf "<< num_variables << " "<= 1"); + } + + MaxSATInstance instance; + // Create a boolean variable for each error, which indicates whether it is activated. 
+ std::vector errors_activated; + for (size_t i=0; i detectors_activated(num_detectors, BoolRef::False()); + std::vector observables_flipped(num_observables, BoolRef::False()); + + size_t error_index = 0; + model.iter_flatten_error_instructions([&](const DemInstruction &e) { + if (e.arg_data[0] != 0) { + BoolRef err_x = errors_activated[error_index]; + // Add parity contribution to the detectors and observables + for (const auto &t : e.target_data) { + if (t.is_relative_detector_id()) { + detectors_activated[t.val()] = instance.Xor(detectors_activated[t.val()], err_x); + } else if (t.is_observable_id()) { + observables_flipped[t.val()] = instance.Xor(observables_flipped[t.val()], err_x); + } + } + // Add a soft clause for this error to be inactive + Clause clause; + clause.add_var(~err_x); + clause.weight = -std::log(e.arg_data[0] / (1 - e.arg_data[0])); + instance.add_clause(clause); + } + ++error_index; + }); + + // Add a hard clause for any observable to be flipped + Clause clause; + for (size_t i=0; i + +#include "stim/dem/detector_error_model.h" + +namespace stim { + +/// Generates a maxSAT problem instance in .wcnf format from a DetectorErrorModel, such that the optimal value of the instance corresponds to the minimum distance of the protocol. +/// +/// The .wcnf (weighted CNF) file format is widely +/// accepted by numerous maxSAT solvers. For example, the solvers in the 2023 maxSAT competition: https://maxsat-evaluations.github.io/2023/descriptions.html +/// Note that performance can vary greatly among solvers. +/// The conversion involves encoding XOR constraints into CNF clauses using standard techniques. +/// +/// Args: +/// model: The detector error model to be converted into .wcnf format for minimum distance calculation. +/// num_distinct_weights: The number of different integer weight values for quantization (default = 1). +/// +/// Returns: +/// A string which is interpreted as the contents of a .wcnf file.
This should be written to a file which can then be passed to various maxSAT +/// solvers to determine the minimum distance of the protocol represented by the model. The optimal value found +/// by the solver corresponds to the minimum distance of the error correction protocol. In other words, the smallest number of errors that cause a logical observable flip without any detection events. +/// +/// Note: +/// The use of .wcnf format offers significant flexibility in choosing a maxSAT solver, but it also means that +/// users must separately manage the process of selecting and running the solver. This approach is designed to +/// sidestep the need for direct integration with any particular solver and allow +/// for experimentation with different solvers to achieve the best performance. +std::string shortest_undetectable_logical_error_wcnf(const DetectorErrorModel &model, size_t num_distinct_weights=1); + +} // namespace stim + +#endif diff --git a/src/stim/search/sat/wcnf.test.cc b/src/stim/search/sat/wcnf.test.cc new file mode 100644 index 000000000..ccac79892 --- /dev/null +++ b/src/stim/search/sat/wcnf.test.cc @@ -0,0 +1,49 @@ +#include "stim/search/sat/wcnf.h" +#include "gtest/gtest.h" + +using namespace stim; + +TEST(shortest_undetectable_logical_error_wcnf, no_error) { + // No error. 
+ ASSERT_THROW( + { stim::shortest_undetectable_logical_error_wcnf(DetectorErrorModel()); }, + std::invalid_argument); +} + +TEST(shortest_undetectable_logical_error_wcnf, single_detector_single_observable) { + std::string wcnf = stim::shortest_undetectable_logical_error_wcnf(DetectorErrorModel(R"DEM( + error(0.1) D0 L0 + error(0.1) D0 + )DEM")); + // There should be 3 variables: x = (x_0, x_1, x_2) + // x_0 -- error 0 occurred + // x_1 -- error 1 occurred + // x_2 -- XOR of x_0 and x_1 + // There should be 2 soft clauses: + // soft clause NOT(x_0) with weight 1 + // soft clause NOT(x_1) with weight 1 + // There should be 4 hard clauses to forbid the strings such that x_2 != XOR(x_0, x_1): + // hard clause x != (0, 0, 1) + // hard clause x != (0, 1, 0) + // hard clause x != (1, 0, 0) + // hard clause x != (1, 1, 1) + // Plus 1 hard clause to ensure an observable is flipped: + // hard clause x_0 + // This gives a total of 7 clauses + // The top value should be at least 1 + 1 + 1 = 3. In our implementation it ends up being 8. + std::stringstream expected; + // WDIMACS header format: p wcnf nbvar nbclauses top + expected << "p wcnf 3 7 8\n"; + // Soft clause + expected << "1 -0\n"; + // Hard clauses + expected << "8 0 1 -2\n"; + expected << "8 0 -1 2\n"; + expected << "8 -0 1 2\n"; + expected << "8 -0 -1 -2\n"; + // Soft clause + expected << "1 -1\n"; + // Hard clause for the observable flipped + expected << "8 0\n"; + ASSERT_EQ(wcnf, expected.str()); +} diff --git a/src/stim/search/search.h b/src/stim/search/search.h index 67e8fbc3b..b110146ca 100644 --- a/src/stim/search/search.h +++ b/src/stim/search/search.h @@ -19,5 +19,6 @@ #include "stim/search/graphlike/algo.h" #include "stim/search/hyper/algo.h" +#include "stim/search/sat/wcnf.h" #endif