Skip to content

Commit

Permalink
fix segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Mar 27, 2024
1 parent 5760f13 commit 9bec3bf
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/stim/dem/detector_error_model_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_append_bad():
m.append("shift_detectors", [], [5])
m += m * 3

with pytest.raises(ValueError, match="Bad target 'D0' for instruction 'shift_detectors'"):
with pytest.raises(ValueError, match=r"Bad target 'stim.target_relative_detector_id\(0\)' for instruction 'shift_detectors'"):
m.append("shift_detectors", [0.125, 0.25], [stim.target_relative_detector_id(0)])
with pytest.raises(ValueError, match="takes 1 argument"):
m.append("error", [0.125, 0.25], [stim.target_relative_detector_id(0)])
Expand Down
26 changes: 21 additions & 5 deletions src/stim/io/raii_file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,35 @@ RaiiFile::RaiiFile(RaiiFile &&other) noexcept : f(other.f), responsible_for_clos
other.f = nullptr;
}

RaiiFile::RaiiFile(std::string_view path, const char *mode) : f(nullptr), responsible_for_closing(true) {
if (path.empty()) {
f = nullptr;
RaiiFile::RaiiFile(const char *optional_path, const char *mode) : f(nullptr), responsible_for_closing(true) {
open(optional_path, mode);
}

RaiiFile::RaiiFile(std::string_view optional_path, const char *mode) : f(nullptr), responsible_for_closing(true) {
open(optional_path, mode);
}

void RaiiFile::open(const char *optional_path, const char *mode) {
done();
if (optional_path == nullptr) {
return;
}
open(std::string_view(optional_path), mode);
}

void RaiiFile::open(std::string_view optional_path, const char *mode) {
done();
if (optional_path.empty()) {
return;
}

// TODO: avoid needing the string copy (for null termination) to safely open the file.
f = fopen(std::string(path).c_str(), mode);
f = fopen(std::string(optional_path).c_str(), mode);

if (f == nullptr) {
std::stringstream ss;
ss << "Failed to open '";
ss << path;
ss << optional_path;
ss << "' for ";
if (*mode == 'r') {
ss << "reading.";
Expand Down
5 changes: 4 additions & 1 deletion src/stim/io/raii_file.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ namespace stim {
struct RaiiFile {
FILE* f;
bool responsible_for_closing;
RaiiFile(std::string_view path, const char* mode);
RaiiFile(const char *optional_path, const char* mode);
RaiiFile(std::string_view optional_path, const char* mode);
RaiiFile(FILE* claim_ownership);
RaiiFile(const RaiiFile& other) = delete;
RaiiFile(RaiiFile&& other) noexcept;
~RaiiFile();
void open(std::string_view optional_path, const char *mode);
void open(const char *optional_path, const char *mode);
void done();
};

Expand Down
40 changes: 36 additions & 4 deletions src/stim/py/compiled_detector_sampler.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,46 @@ pybind11::object CompiledDetectorSampler::sample_to_numpy(

void CompiledDetectorSampler::sample_write(
size_t num_samples,
std::string_view filepath,
pybind11::object filepath_obj,
std::string_view format,
bool prepend_observables,
bool append_observables,
std::string_view obs_out_filepath,
pybind11::object obs_out_filepath_obj,
std::string_view obs_out_format) {
auto f = format_to_enum(format);

auto py_path = pybind11::module::import("pathlib").attr("Path");
if (pybind11::isinstance(filepath_obj, py_path)) {
filepath_obj = pybind11::str(filepath_obj);
}
if (pybind11::isinstance(obs_out_filepath_obj, py_path)) {
obs_out_filepath_obj = pybind11::str(obs_out_filepath_obj);
}

std::string_view filepath;
if (pybind11::isinstance<pybind11::str>(filepath_obj)) {
filepath = pybind11::cast<std::string_view>(filepath_obj);
} else {
std::stringstream ss;
ss << "Don't know how to write to ";
ss << pybind11::repr(filepath_obj);
throw std::invalid_argument(ss.str());
}

std::string_view obs_out_filepath_view;
if (pybind11::isinstance<pybind11::str>(obs_out_filepath_obj)) {
obs_out_filepath_view = pybind11::cast<std::string_view>(obs_out_filepath_obj);
} else if (obs_out_filepath_obj.is_none()) {
// Empty string view does the right thing.
} else {
std::stringstream ss;
ss << "Don't know how to write observables to ";
ss << pybind11::repr(obs_out_filepath_obj);
throw std::invalid_argument(ss.str());
}

RaiiFile out(filepath, "wb");
RaiiFile obs_out(obs_out_filepath, "wb");
RaiiFile obs_out(obs_out_filepath_view, "wb");
auto parsed_obs_out_format = format_to_enum(obs_out_format);
sample_batch_detection_events_writing_results_to_disk<MAX_BITWORD_WIDTH>(
circuit,
Expand Down Expand Up @@ -291,9 +322,10 @@ void stim_pybind::pybind_compiled_detector_sampler_methods(
pybind11::arg("format") = "01",
pybind11::arg("prepend_observables") = false,
pybind11::arg("append_observables") = false,
pybind11::arg("obs_out_filepath") = nullptr,
pybind11::arg("obs_out_filepath") = pybind11::none(),
pybind11::arg("obs_out_format") = "01",
clean_doc_string(R"DOC(
@signature def sample_write(shots: int, *, filepath: Union[str, pathlib.Path], format: 'Literal["01", "b8", "r8", "ptb64", "hits", "dets"]' = '01', obs_out_filepath: Optional[Union[str, pathlib.Path]] = None, obs_out_format: 'Literal["01", "b8", "r8", "ptb64", "hits", "dets"]' = '01', prepend_observables: bool = False, append_observables: bool = False):
Samples detection events from the circuit and writes them to a file.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/stim/py/compiled_detector_sampler.pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ struct CompiledDetectorSampler {
bool bit_packed);
void sample_write(
size_t num_samples,
std::string_view filepath,
pybind11::object filepath_obj,
std::string_view format,
bool prepend_observables,
bool append_observables,
std::string_view obs_out_filepath,
pybind11::object obs_out_filepath_obj,
std::string_view obs_out_format);
std::string repr() const;
};
Expand Down
17 changes: 17 additions & 0 deletions src/stim/py/compiled_detector_sampler_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def test_compiled_detector_sampler_sample():
with open(path, 'rb') as f:
assert f.read() == b'\x03' * 5

c.compile_detector_sampler().sample_write(5, filepath=pathlib.Path(path), format='b8')
with open(path, 'rb') as f:
assert f.read() == b'\x03' * 5

c.compile_detector_sampler().sample_write(5, filepath=path, format='01', prepend_observables=True)
with open(path, 'r') as f:
assert f.readlines() == ['1000110\n'] * 5
Expand Down Expand Up @@ -182,6 +186,19 @@ def test_write_obs_file():
with open(d / 'obs') as f:
assert f.read() == '1\n' * 100

with tempfile.TemporaryDirectory() as d:
d = pathlib.Path(d)
r.sample_write(
shots=100,
filepath=d / 'det',
format='dets',
obs_out_filepath=d / 'obs',
obs_out_format='hits',
)
with open(d / 'det') as f:
assert f.read() == 'shot D3\n' * 100
with open(d / 'obs') as f:
assert f.read() == '1\n' * 100

def test_detector_sampler_actually_fills_array():
circuit = stim.Circuit('''
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void CompiledMeasurementsToDetectionEventsConverter::convert_file(
auto format_sweep_bits = format_to_enum(sweep_bits_format);
auto format_out = format_to_enum(detection_events_format);
RaiiFile file_in(measurements_filepath, "rb");
RaiiFile obs_out(obs_out_filepath, "wb");
RaiiFile obs_out(obs_out_filepath == nullptr ? std::string_view{} : std::string_view{obs_out_filepath}, "wb");
RaiiFile sweep_bits_in(sweep_bits_filepath, "rb");
RaiiFile detections_out(detection_events_filepath, "wb");
auto parsed_obs_out_format = format_to_enum(obs_out_format);
Expand Down
13 changes: 8 additions & 5 deletions src/stim/stabilizers/tableau_iter.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,21 @@ TEST_EACH_WORD_SIZE_W(tableau_iter, iter_tableau, {
while (iter1.iter_next()) {
n1++;
}
ASSERT_EQ(n1, 6);

while (iter1_signs.iter_next()) {
s1++;
}
ASSERT_EQ(s1, 24);

while (iter2.iter_next()) {
n2++;
}
while (iter3.iter_next()) {
n3++;
}
ASSERT_EQ(n1, 6);
ASSERT_EQ(s1, 24);
ASSERT_EQ(n2, 720);

// while (iter3.iter_next()) {
// n3++;
// }
// ASSERT_EQ(n3, 1451520); // Note: disabled because it takes 2-3 seconds.
})

Expand Down

0 comments on commit 9bec3bf

Please sign in to comment.