Skip to content

Commit

Permalink
Finish up
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Dec 10, 2023
1 parent 582ade0 commit 03c4306
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 75 deletions.
120 changes: 50 additions & 70 deletions src/stim/simulators/frame_simulator.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,32 @@ std::optional<size_t> py_index_to_optional_size_t(
return (size_t)i;
}

uint8_t pybind11_object_to_pauli_ixyz(const pybind11::object &obj) {
if (pybind11::isinstance<pybind11::str>(obj)) {
std::string s = pybind11::cast<std::string>(obj);
if (s == "X") {
return 1;
} else if (s == "Y") {
return 2;
} else if (s == "Z") {
return 3;
} else if (s == "I" || s == "_") {
return 0;
}
} else if (pybind11::isinstance<pybind11::int_>(obj)) {
uint8_t v = 255;
try {
v = pybind11::cast<uint8_t>(obj);
} catch (const pybind11::cast_error &) {
}
if (v < 4) {
return (uint8_t)v;
}
}

throw std::invalid_argument("Need pauli in ['I', 'X', 'Y', 'Z', 0, 1, 2, 3, '_'].");
}

pybind11::class_<FrameSimulator<MAX_BITWORD_WIDTH>> stim_pybind::pybind_frame_simulator(pybind11::module &m) {
return pybind11::class_<FrameSimulator<MAX_BITWORD_WIDTH>>(
m,
Expand Down Expand Up @@ -374,27 +400,7 @@ void stim_pybind::pybind_frame_simulator_methods(
const pybind11::object &pauli,
int64_t qubit_index,
int64_t instance_index) {
uint8_t p = 255;
try {
p = pybind11::cast<uint8_t>(pauli);
} catch (const pybind11::cast_error &) {
try {
std::string s = pybind11::cast<std::string>(pauli);
if (s == "X") {
p = 1;
} else if (s == "Y") {
p = 2;
} else if (s == "Z") {
p = 3;
} else if (s == "I" || s == "_") {
p = 0;
}
} catch (const pybind11::cast_error &) {
}
}
if (p > 3) {
throw std::invalid_argument("Expected pauli in [0, 1, 2, 3, '_', 'I', 'X', 'Y', 'Z']");
}
uint8_t p = pybind11_object_to_pauli_ixyz(pauli);
if (instance_index < 0) {
instance_index += self.batch_size;
}
Expand All @@ -409,6 +415,7 @@ void stim_pybind::pybind_frame_simulator_methods(
stats.num_qubits = qubit_index + 1;
self.ensure_safe_to_do_circuit_with_stats(stats);
}

p ^= p >> 1;
self.x_table[qubit_index][instance_index] = (p & 1) != 0;
self.z_table[qubit_index][instance_index] = (p & 2) != 0;
Expand Down Expand Up @@ -773,67 +780,40 @@ void stim_pybind::pybind_frame_simulator_methods(

c.def(
"broadcast_pauli_errors",
[](FrameSimulator<MAX_BITWORD_WIDTH> &self,
const pybind11::object &pauli,
const pybind11::object &mask
) {
uint8_t p = 255;
try {
p = pybind11::cast<uint8_t>(pauli);
if (p >= 4) {
throw pybind11::cast_error();
}
} catch (const pybind11::cast_error &) {
try {
std::string s = pybind11::cast<std::string>(pauli);
if (s == "X") {
p = 1;
} else if (s == "Y") {
p = 2;
} else if (s == "Z") {
p = 3;
} else if (s == "I" || s == "_") {
p = 0;
} else {
throw pybind11::cast_error();
}
} catch (const pybind11::cast_error &) {
throw std::invalid_argument(
"broadcast_pauli_errors only accepts pauli arguments in ['I', '_', 'X', 'Y', 'Z', 0,1,2,3]");
}
}

bool flip_z_part = p & 2;
bool flip_x_part = (0b0110 >> p) & 1; // parity of 2 bit number

[](FrameSimulator<MAX_BITWORD_WIDTH> &self, const pybind11::object &pauli, const pybind11::object &mask) {
uint8_t p = pybind11_object_to_pauli_ixyz(pauli);

if (!pybind11::isinstance<pybind11::array_t<bool>>(mask)) {
throw std::invalid_argument(
"broadcast_pauli_errors can only accept mask that is a 2D array of np.bool_");
throw std::invalid_argument("Need isinstance(mask, np.ndarray) and mask.dtype == np.bool_");
}
const pybind11::array_t<bool> &arr = pybind11::cast<pybind11::array_t<bool>>(mask);

if (arr.ndim() != 2) {
throw std::invalid_argument(
"broadcast_pauli_errors can only accept mask that is a 2D array of np.bool_");
"Need a 2d mask (first axis is qubits, second axis is simulation instances). Need len(mask.shape) "
"== 2.");
}

size_t major = arr.shape(0);
size_t minor = arr.shape(1);

if (minor != self.batch_size) {
throw std::invalid_argument(
"broadcast_pauli_errors can only accept mask that has minor shape equal to the batch_size");
pybind11::ssize_t s_mask_num_qubits = arr.shape(0);
pybind11::ssize_t s_mask_batch_size = arr.shape(1);
if ((uint64_t)s_mask_batch_size != self.batch_size) {
throw std::invalid_argument("Need mask.shape[1] == flip_sim.batch_size");
}
if (s_mask_num_qubits > UINT32_MAX) {
throw std::invalid_argument("Mask exceeds maximum number of simulated qubits.");
}
uint32_t mask_num_qubits = (uint32_t)s_mask_num_qubits;
uint32_t mask_batch_size = (uint32_t)s_mask_batch_size;

self.ensure_safe_to_do_circuit_with_stats(CircuitStats{.num_qubits=(uint32_t)major});

self.ensure_safe_to_do_circuit_with_stats(CircuitStats{.num_qubits = mask_num_qubits});
auto u = arr.unchecked<2>();
for (size_t i = 0; i < major; i++){
for (size_t j = 0; j < minor; j++){
auto b = u.data(i, j);
self.x_table[i][j] ^= *b & flip_x_part;
self.z_table[i][j] ^= *b & flip_z_part;
bool p_x = (0b0110 >> p) & 1; // parity of 2 bit number
bool p_z = p & 2;
for (size_t i = 0; i < mask_num_qubits; i++) {
for (size_t j = 0; j < mask_batch_size; j++) {
bool b = *u.data(i, j);
self.x_table[i][j] ^= b & p_x;
self.z_table[i][j] ^= b & p_z;
}
}
},
Expand Down
10 changes: 5 additions & 5 deletions src/stim/simulators/frame_simulator_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ def test_set_pauli_flip():
stim.PauliString('XZ_'),
]

with pytest.raises(ValueError, match='Expected pauli'):
with pytest.raises(ValueError, match='pauli'):
sim.set_pauli_flip(-1, qubit_index=0, instance_index=0)
with pytest.raises(ValueError, match='Expected pauli'):
with pytest.raises(ValueError, match='pauli'):
sim.set_pauli_flip(4, qubit_index=0, instance_index=0)
with pytest.raises(ValueError, match='Expected pauli'):
with pytest.raises(ValueError, match='pauli'):
sim.set_pauli_flip('R', qubit_index=0, instance_index=0)
with pytest.raises(ValueError, match='Expected pauli'):
with pytest.raises(ValueError, match='pauli'):
sim.set_pauli_flip('XY', qubit_index=0, instance_index=0)
with pytest.raises(ValueError, match='Expected pauli'):
with pytest.raises(ValueError, match='pauli'):
sim.set_pauli_flip(object(), qubit_index=0, instance_index=0)

with pytest.raises(IndexError, match='instance_index'):
Expand Down

0 comments on commit 03c4306

Please sign in to comment.