From 2feae5a2efac3b84fcc923147971de3e50943ab6 Mon Sep 17 00:00:00 2001 From: Matt McEwen Date: Fri, 8 Dec 2023 12:41:54 -0800 Subject: [PATCH] prevent running with scisors --- doc/python_api_reference_vDev.md | 18 +++++-- doc/stim.pyi | 18 +++++-- glue/python/src/stim/__init__.pyi | 18 +++++-- src/stim/simulators/frame_simulator.pybind.cc | 26 +++++++-- .../simulators/frame_simulator_pybind_test.py | 54 ++++++++++++++++++- 5 files changed, 116 insertions(+), 18 deletions(-) diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index 081c06d0b..4e9c3d6a1 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -6165,15 +6165,25 @@ def broadcast_pauli_errors( pauli: Union[str, int], mask: np.ndarray, ) -> None: - """Applies a pauli over all qubits in all simulation indices, filtered by mask. + """Applies a pauli error to all qubits in all instances, filtered by a mask. Args: pauli: The pauli, specified as an integer or string. Uses the convention 0=I, 1=X, 2=Y, 3=Z. Any value from [0, 1, 2, 3, 'X', 'Y', 'Z', 'I', '_'] is allowed. - mask: a np.bool_ array with shape (qubit, simulation_instance) - The pauli error is only applied to qubits q and simulation indices k - where mask[q, k] == True + mask: A 2d numpy array specifying where to apply errors. The first axis + is qubits, the second axis is simulation instances. The first axis + can have a length less than the current number of qubits (or more, + which adds qubits to the simulation). The length of the second axis + must match the simulator's `batch_size`. The array must satisfy + + mask.dtype == np.bool_ + len(mask.shape) == 2 + mask.shape[1] == flip_sim.batch_size + + The error is only applied to qubit q in instance k when + + mask[q, k] == True. Examples: >>> import stim diff --git a/doc/stim.pyi b/doc/stim.pyi index 67c2f3677..6a7ff42f8 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -4676,15 +4676,25 @@ class FlipSimulator: pauli: Union[str, int], mask: np.ndarray, ) -> None: - """Applies a pauli over all qubits in all simulation indices, filtered by mask. + """Applies a pauli error to all qubits in all instances, filtered by a mask. Args: pauli: The pauli, specified as an integer or string. Uses the convention 0=I, 1=X, 2=Y, 3=Z. Any value from [0, 1, 2, 3, 'X', 'Y', 'Z', 'I', '_'] is allowed. - mask: a np.bool_ array with shape (qubit, simulation_instance) - The pauli error is only applied to qubits q and simulation indices k - where mask[q, k] == True + mask: A 2d numpy array specifying where to apply errors. The first axis + is qubits, the second axis is simulation instances. The first axis + can have a length less than the current number of qubits (or more, + which adds qubits to the simulation). The length of the second axis + must match the simulator's `batch_size`. The array must satisfy + + mask.dtype == np.bool_ + len(mask.shape) == 2 + mask.shape[1] == flip_sim.batch_size + + The error is only applied to qubit q in instance k when + + mask[q, k] == True. Examples: >>> import stim diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index 67c2f3677..6a7ff42f8 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -4676,15 +4676,25 @@ class FlipSimulator: pauli: Union[str, int], mask: np.ndarray, ) -> None: - """Applies a pauli over all qubits in all simulation indices, filtered by mask. + """Applies a pauli error to all qubits in all instances, filtered by a mask. Args: pauli: The pauli, specified as an integer or string. Uses the convention 0=I, 1=X, 2=Y, 3=Z. Any value from [0, 1, 2, 3, 'X', 'Y', 'Z', 'I', '_'] is allowed. - mask: a np.bool_ array with shape (qubit, simulation_instance) - The pauli error is only applied to qubits q and simulation indices k - where mask[q, k] == True + mask: A 2d numpy array specifying where to apply errors. The first axis + is qubits, the second axis is simulation instances. The first axis + can have a length less than the current number of qubits (or more, + which adds qubits to the simulation). The length of the second axis + must match the simulator's `batch_size`. The array must satisfy + + mask.dtype == np.bool_ + len(mask.shape) == 2 + mask.shape[1] == flip_sim.batch_size + + The error is only applied to qubit q in instance k when + + mask[q, k] == True. Examples: >>> import stim diff --git a/src/stim/simulators/frame_simulator.pybind.cc b/src/stim/simulators/frame_simulator.pybind.cc index f521cad5c..4ac38ec36 100644 --- a/src/stim/simulators/frame_simulator.pybind.cc +++ b/src/stim/simulators/frame_simulator.pybind.cc @@ -820,6 +820,14 @@ void stim_pybind::pybind_frame_simulator_methods( size_t major = arr.shape(0); size_t minor = arr.shape(1); + + if (minor != self.batch_size) { + throw std::invalid_argument( + "broadcast_pauli_errors can only accept mask that has minor shape equal to the batch_size"); + } + + self.ensure_safe_to_do_circuit_with_stats(CircuitStats{.num_qubits=(uint32_t)major}); + auto u = arr.unchecked<2>(); for (size_t i = 0; i < major; i++){ for (size_t j = 0; j < minor; j++){ @@ -834,15 +842,25 @@ void stim_pybind::pybind_frame_simulator_methods( pybind11::arg("mask"), clean_doc_string(R"DOC( @signature def broadcast_pauli_errors(self, *, pauli: Union[str, int], mask: np.ndarray) -> None: - Applies a pauli over all qubits in all simulation indices, filtered by mask. + Applies a pauli error to all qubits in all instances, filtered by a mask. Args: pauli: The pauli, specified as an integer or string. Uses the convention 0=I, 1=X, 2=Y, 3=Z. Any value from [0, 1, 2, 3, 'X', 'Y', 'Z', 'I', '_'] is allowed. - mask: a np.bool_ array with shape (qubit, simulation_instance) - The pauli error is only applied to qubits q and simulation indices k - where mask[q, k] == True + mask: A 2d numpy array specifying where to apply errors. The first axis + is qubits, the second axis is simulation instances. The first axis + can have a length less than the current number of qubits (or more, + which adds qubits to the simulation). The length of the second axis + must match the simulator's `batch_size`. The array must satisfy + + mask.dtype == np.bool_ + len(mask.shape) == 2 + mask.shape[1] == flip_sim.batch_size + + The error is only applied to qubit q in instance k when + + mask[q, k] == True. Examples: >>> import stim diff --git a/src/stim/simulators/frame_simulator_pybind_test.py b/src/stim/simulators/frame_simulator_pybind_test.py index 8ceb96283..63f1dd7d9 100644 --- a/src/stim/simulators/frame_simulator_pybind_test.py +++ b/src/stim/simulators/frame_simulator_pybind_test.py @@ -334,7 +334,7 @@ def test_broadcast_pauli_errors(): stim.PauliString("+ZYZ") ] - with pytest.raises(Exception): + with pytest.raises(ValueError, match='pauli'): sim.broadcast_pauli_errors( pauli='whoops', mask=np.asarray([ @@ -343,7 +343,7 @@ def test_broadcast_pauli_errors(): [True, True]] ), ) - with pytest.raises(Exception): + with pytest.raises(ValueError, match='pauli'): sim.broadcast_pauli_errors( pauli=4, mask=np.asarray([ @@ -352,6 +352,56 @@ def test_broadcast_pauli_errors(): [True, True]] ), ) + with pytest.raises(ValueError, match='batch_size'): + sim.broadcast_pauli_errors( + pauli='X', + mask=np.asarray([ + [True, True,True], + [False, True, True], + [True, True, True]] + ), + ) + with pytest.raises(ValueError, match='batch_size'): + sim.broadcast_pauli_errors( + pauli='X', + mask=np.asarray([ + [True], + [False], + [True]] + ), + ) + sim = stim.FlipSimulator( + batch_size=2, + num_qubits=3, + disable_stabilizer_randomization=True, + ) + sim.broadcast_pauli_errors( + pauli='X', + mask=np.asarray([ + [True, False], + [False, False], + [True, True], + [True, True]] + ), + ) # automatically expands the qubit basis + peek = sim.peek_pauli_flips() + assert peek == [ + stim.PauliString("+X_XX"), + stim.PauliString("+__XX") + ] + sim.broadcast_pauli_errors( + pauli='X', + mask=np.asarray([ + [True, False], + [False, False], + ] + ), + ) # tolerates fewer qubits in mask than in simulator + peek = sim.peek_pauli_flips() + assert peek == [ + stim.PauliString("+__XX"), + stim.PauliString("+__XX") + ] def test_repro_heralded_pauli_channel_1_bug():