Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use enum class instead of enum, and autoformat the code #652

Merged
merged 5 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/stim/benchmark_util.perf.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ inline void add_benchmark(RegisteredBenchmark benchmark) {
all_registered_benchmarks_data->push_back(benchmark);
}

#define BENCHMARK(name) \
void BENCH_##name##_METHOD(); \
struct BENCH_STARTUP_TYPE_##name { \
BENCH_STARTUP_TYPE_##name() { \
add_benchmark({#name, BENCH_##name##_METHOD}); \
} \
}; \
static BENCH_STARTUP_TYPE_##name BENCH_STARTUP_INSTANCE_##name; \
#define BENCHMARK(name) \
void BENCH_##name##_METHOD(); \
struct BENCH_STARTUP_TYPE_##name { \
BENCH_STARTUP_TYPE_##name() { \
add_benchmark({#name, BENCH_##name##_METHOD}); \
} \
}; \
static BENCH_STARTUP_TYPE_##name BENCH_STARTUP_INSTANCE_##name; \
void BENCH_##name##_METHOD()

// HACK: Templating the body function type makes inlining significantly more likely.
Expand Down
24 changes: 12 additions & 12 deletions src/stim/circuit/circuit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

using namespace stim;

enum READ_CONDITION {
enum class READ_CONDITION {
READ_AS_LITTLE_AS_POSSIBLE,
READ_UNTIL_END_OF_BLOCK,
READ_UNTIL_END_OF_FILE,
Expand Down Expand Up @@ -351,13 +351,13 @@ void circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION
int c = read_char();
read_past_dead_space_between_commands(c, read_char);
if (c == EOF) {
if (read_condition == READ_UNTIL_END_OF_BLOCK) {
if (read_condition == READ_CONDITION::READ_UNTIL_END_OF_BLOCK) {
throw std::invalid_argument("Unterminated block. Got a '{' without an eventual '}'.");
}
return;
}
if (c == '}') {
if (read_condition != READ_UNTIL_END_OF_BLOCK) {
if (read_condition != READ_CONDITION::READ_UNTIL_END_OF_BLOCK) {
throw std::invalid_argument("Uninitiated block. Got a '}' without a '{'.");
}
return;
Expand All @@ -378,7 +378,7 @@ void circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION

// Read block.
circuit.blocks.emplace_back();
circuit_read_operations(circuit.blocks.back(), read_char, READ_UNTIL_END_OF_BLOCK);
circuit_read_operations(circuit.blocks.back(), read_char, READ_CONDITION::READ_UNTIL_END_OF_BLOCK);

// Rewrite target data to reference the parsed block.
circuit.target_buf.ensure_available(3);
Expand All @@ -390,7 +390,7 @@ void circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION

// Fuse operations.
circuit.try_fuse_last_two_ops();
} while (read_condition != READ_AS_LITTLE_AS_POSSIBLE);
} while (read_condition != READ_CONDITION::READ_AS_LITTLE_AS_POSSIBLE);
}

void Circuit::append_from_text(const char *text) {
Expand All @@ -400,7 +400,7 @@ void Circuit::append_from_text(const char *text) {
[&]() {
return text[k] != 0 ? text[k++] : EOF;
},
READ_UNTIL_END_OF_FILE);
READ_CONDITION::READ_UNTIL_END_OF_FILE);
}

void Circuit::safe_append(const CircuitInstruction &operation) {
Expand Down Expand Up @@ -433,7 +433,7 @@ void Circuit::safe_append_u(
}

void Circuit::safe_append(GateType gate_type, SpanRef<const GateTarget> targets, SpanRef<const double> args) {
auto flags = GATE_DATA.items[gate_type].flags;
auto flags = GATE_DATA[gate_type].flags;
if (flags & GATE_IS_BLOCK) {
throw std::invalid_argument("Can't append a block like a normal operation.");
}
Expand All @@ -460,11 +460,11 @@ void Circuit::append_from_file(FILE *file, bool stop_asap) {
[&]() {
return getc(file);
},
stop_asap ? READ_AS_LITTLE_AS_POSSIBLE : READ_UNTIL_END_OF_FILE);
stop_asap ? READ_CONDITION::READ_AS_LITTLE_AS_POSSIBLE : READ_CONDITION::READ_UNTIL_END_OF_FILE);
}

std::ostream &stim::operator<<(std::ostream &out, const CircuitInstruction &instruction) {
out << GATE_DATA.items[instruction.gate_type].name;
out << GATE_DATA[instruction.gate_type].name;
if (!instruction.args.empty()) {
out << '(';
bool first = true;
Expand Down Expand Up @@ -762,7 +762,7 @@ const Circuit Circuit::aliased_noiseless_circuit() const {
// HACK: result has pointers into `circuit`!
Circuit result;
for (const auto &op : operations) {
auto flags = GATE_DATA.items[op.gate_type].flags;
auto flags = GATE_DATA[op.gate_type].flags;
if (flags & GATE_PRODUCES_RESULTS) {
if (op.gate_type == GateType::HERALDED_ERASE || op.gate_type == GateType::HERALDED_PAULI_CHANNEL_1) {
// Replace heralded errors with fixed MPAD.
Expand Down Expand Up @@ -794,7 +794,7 @@ const Circuit Circuit::aliased_noiseless_circuit() const {
Circuit Circuit::without_noise() const {
Circuit result;
for (const auto &op : operations) {
auto flags = GATE_DATA.items[op.gate_type].flags;
auto flags = GATE_DATA[op.gate_type].flags;
if (flags & GATE_PRODUCES_RESULTS) {
if (op.gate_type == GateType::HERALDED_ERASE || op.gate_type == GateType::HERALDED_PAULI_CHANNEL_1) {
// Replace heralded errors with fixed MPAD.
Expand Down Expand Up @@ -886,7 +886,7 @@ Circuit Circuit::inverse(bool allow_weak_inverse) const {
}

SpanRef<const double> args = op.args;
const auto &gate_data = GATE_DATA.items[op.gate_type];
const auto &gate_data = GATE_DATA[op.gate_type];
auto flags = gate_data.flags;
if (flags & GATE_IS_UNITARY) {
// Unitary gates always have an inverse.
Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/circuit.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_<Ci
for (auto t : op.args) {
args.append(t);
}
const auto &gate_data = GATE_DATA.items[op.gate_type];
const auto &gate_data = GATE_DATA[op.gate_type];
if (op.args.empty()) {
// Backwards compatibility.
result.append(pybind11::make_tuple(gate_data.name, targets, 0));
Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/circuit.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ Circuit stim::generate_test_circuit_with_all_operations() {

TEST(circuit, generate_test_circuit_with_all_operations) {
auto c = generate_test_circuit_with_all_operations();
std::set<GateType> seen{NOT_A_GATE};
std::set<GateType> seen{GateType::NOT_A_GATE};
for (const auto &instruction : c.operations) {
seen.insert(instruction.gate_type);
}
Expand Down
8 changes: 4 additions & 4 deletions src/stim/circuit/circuit_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ CircuitStats CircuitInstruction::compute_stats(const Circuit *host) const {
}

void CircuitInstruction::add_stats_to(CircuitStats &out, const Circuit *host) const {
if (gate_type == REPEAT) {
if (gate_type == GateType::REPEAT) {
if (host == nullptr) {
throw std::invalid_argument("gate_type == REPEAT && host == nullptr");
}
Expand Down Expand Up @@ -110,7 +110,7 @@ CircuitInstruction::CircuitInstruction(
}

void CircuitInstruction::validate() const {
const Gate &gate = GATE_DATA.items[gate_type];
const Gate &gate = GATE_DATA[gate_type];

if (gate.flags == GateFlags::NO_GATE_FLAG) {
throw std::invalid_argument("Unrecognized gate_type. Associated flag is NO_GATE_FLAG.");
Expand Down Expand Up @@ -248,7 +248,7 @@ void CircuitInstruction::validate() const {
}

uint64_t CircuitInstruction::count_measurement_results() const {
auto flags = GATE_DATA.items[gate_type].flags;
auto flags = GATE_DATA[gate_type].flags;
if (!(flags & GATE_PRODUCES_RESULTS)) {
return 0;
}
Expand All @@ -266,7 +266,7 @@ uint64_t CircuitInstruction::count_measurement_results() const {
}

bool CircuitInstruction::can_fuse(const CircuitInstruction &other) const {
auto flags = GATE_DATA.items[gate_type].flags;
auto flags = GATE_DATA[gate_type].flags;
return gate_type == other.gate_type && args == other.args && !(flags & GATE_IS_NOT_FUSABLE);
}

Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/circuit_instruction.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ PyCircuitInstruction::operator CircuitInstruction() const {
return as_operation_ref();
}
std::string PyCircuitInstruction::name() const {
return GATE_DATA.items[gate_type].name;
return GATE_DATA[gate_type].name;
}
std::vector<uint32_t> PyCircuitInstruction::raw_targets() const {
std::vector<uint32_t> result;
Expand Down
12 changes: 5 additions & 7 deletions src/stim/circuit/gate_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#include <complex>

#include "stim/circuit/stabilizer_flow.h"

using namespace stim;

GateDataMap::GateDataMap() {
Expand Down Expand Up @@ -76,7 +74,7 @@ std::vector<std::vector<std::complex<float>>> Gate::unitary() const {
const Gate &Gate::inverse() const {
std::string inv_name = name;
if ((flags & GATE_IS_UNITARY) || id == GateType::TICK) {
return GATE_DATA.items[static_cast<uint8_t>(best_candidate_inverse_id)];
return GATE_DATA[best_candidate_inverse_id];
}
throw std::out_of_range(inv_name + " has no inverse.");
}
Expand All @@ -101,16 +99,16 @@ Gate::Gate(
}

void GateDataMap::add_gate(bool &failed, const Gate &gate) {
assert(gate.id < NUM_DEFINED_GATES);
assert((size_t)gate.id < NUM_DEFINED_GATES);
const char *c = gate.name;
auto h = gate_name_to_hash(c);
auto &hash_loc = hashed_name_to_gate_type_table[h];
if (hash_loc.expected_name_len != 0) {
std::cerr << "GATE COLLISION " << gate.name << " vs " << items[hash_loc.id].name << "\n";
std::cerr << "GATE COLLISION " << gate.name << " vs " << items[(size_t)hash_loc.id].name << "\n";
failed = true;
return;
}
items[gate.id] = gate;
items[(size_t)gate.id] = gate;
hash_loc.id = gate.id;
hash_loc.expected_name = gate.name;
hash_loc.expected_name_len = gate.name_len;
Expand All @@ -120,7 +118,7 @@ void GateDataMap::add_gate_alias(bool &failed, const char *alt_name, const char
auto h_alt = gate_name_to_hash(alt_name);
auto &hash_loc = hashed_name_to_gate_type_table[h_alt];
if (hash_loc.expected_name_len != 0) {
std::cerr << "GATE COLLISION " << alt_name << " vs " << items[hash_loc.id].name << "\n";
std::cerr << "GATE COLLISION " << alt_name << " vs " << items[(size_t)hash_loc.id].name << "\n";
failed = true;
return;
}
Expand Down
14 changes: 9 additions & 5 deletions src/stim/circuit/gate_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ constexpr inline uint16_t gate_name_to_hash(const char *c) {

constexpr const size_t NUM_DEFINED_GATES = 67;

enum GateType : uint8_t {
enum class GateType : uint8_t {
NOT_A_GATE = 0,
// Annotations
DETECTOR,
Expand Down Expand Up @@ -258,7 +258,7 @@ struct Gate {

template <size_t W>
Tableau<W> tableau() const {
if (!(flags & GATE_IS_UNITARY)) {
if (!(flags & GateFlags::GATE_IS_UNITARY)) {
throw std::invalid_argument(std::string(name) + " isn't unitary so it doesn't have a tableau.");
}
const auto &tableau_data = extra_data_func().flow_data;
Expand All @@ -274,9 +274,9 @@ struct Gate {

template <size_t W>
std::vector<StabilizerFlow<W>> flows() const {
if (flags & GATE_IS_UNITARY) {
if (flags & GateFlags::GATE_IS_UNITARY) {
auto t = tableau<W>();
if (flags & GATE_TARGETS_PAIRS) {
if (flags & GateFlags::GATE_TARGETS_PAIRS) {
return {
StabilizerFlow<W>{stim::PauliString<W>::from_str("X_"), t.xs[0], {}},
StabilizerFlow<W>{stim::PauliString<W>::from_str("Z_"), t.zs[0], {}},
Expand Down Expand Up @@ -340,14 +340,18 @@ struct GateDataMap {
std::array<Gate, NUM_DEFINED_GATES> items;
GateDataMap();

inline const Gate &operator[](GateType g) const {
return items[(uint64_t)g];
}

inline const Gate &at(const char *text, size_t text_len) const {
auto h = gate_name_to_hash(text, text_len);
const auto &entry = hashed_name_to_gate_type_table[h];
if (_case_insensitive_mismatch(text, text_len, entry.expected_name, entry.expected_name_len)) {
throw std::out_of_range("Gate not found: '" + std::string(text, text_len) + "'");
}
// Canonicalize.
return items[entry.id];
return (*this)[entry.id];
}

inline const Gate &at(const char *text) const {
Expand Down
2 changes: 1 addition & 1 deletion src/stim/circuit/gate_data.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ void stim_pybind::pybind_gate_data_methods(pybind11::module &m, pybind11::class_

std::map<std::string, Gate> result;
for (const auto &g : GATE_DATA.items) {
if (g.id != NOT_A_GATE) {
if (g.id != GateType::NOT_A_GATE) {
result.insert({g.name, g});
}
}
Expand Down
28 changes: 15 additions & 13 deletions src/stim/circuit/gate_data.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,34 @@ TEST(gate_data, lookup) {
}

TEST(gate_data, zero_flag_means_not_a_gate) {
ASSERT_EQ(GATE_DATA.items[0].id, 0);
ASSERT_EQ(GATE_DATA.items[0].flags, GateFlags::NO_GATE_FLAG);
ASSERT_EQ((GateType)0, GateType::NOT_A_GATE);
ASSERT_EQ(GATE_DATA[(GateType)0].id, (GateType)0);
ASSERT_EQ(GATE_DATA[(GateType)0].flags, GateFlags::NO_GATE_FLAG);
for (size_t k = 0; k < GATE_DATA.items.size(); k++) {
const auto &g = GATE_DATA.items[k];
if (g.id != 0) {
const auto &g = GATE_DATA[(GateType)k];
if (g.id != GateType::NOT_A_GATE) {
EXPECT_NE(g.flags, GateFlags::NO_GATE_FLAG) << g.name;
}
}
}

TEST(gate_data, one_step_to_canonical_gate) {
for (size_t k = 0; k < GATE_DATA.items.size(); k++) {
const auto &g = GATE_DATA.items[k];
if (g.id != 0) {
EXPECT_TRUE(g.id == k || GATE_DATA.items[g.id].id == g.id) << g.name;
const auto &g = GATE_DATA[(GateType)k];
if (g.id != GateType::NOT_A_GATE) {
EXPECT_TRUE(g.id == (GateType)k || GATE_DATA[g.id].id == g.id) << g.name;
}
}
}

TEST(gate_data, hash_matches_storage_location) {
ASSERT_EQ(GATE_DATA.items[0].id, 0);
ASSERT_EQ(GATE_DATA.items[0].flags, GateFlags::NO_GATE_FLAG);
ASSERT_EQ((GateType)0, GateType::NOT_A_GATE);
ASSERT_EQ(GATE_DATA[(GateType)0].id, (GateType)0);
ASSERT_EQ(GATE_DATA[(GateType)0].flags, GateFlags::NO_GATE_FLAG);
for (size_t k = 0; k < GATE_DATA.items.size(); k++) {
const auto &g = GATE_DATA.items[k];
EXPECT_EQ(g.id, k) << g.name;
if (g.id != 0) {
const auto &g = GATE_DATA[(GateType)k];
EXPECT_EQ(g.id, (GateType)k) << g.name;
if (g.id != GateType::NOT_A_GATE) {
EXPECT_EQ(GATE_DATA.hashed_name_to_gate_type_table[gate_name_to_hash(g.name)].id, g.id) << g.name;
}
}
Expand Down Expand Up @@ -132,7 +134,7 @@ TEST_EACH_WORD_SIZE_W(gate_data, unitary_inverses_are_correct, {
for (const auto &g : GATE_DATA.items) {
if (g.flags & GATE_IS_UNITARY) {
auto g_t_inv = g.tableau<W>().inverse(false);
auto g_inv_t = GATE_DATA.items[static_cast<uint8_t>(g.best_candidate_inverse_id)].tableau<W>();
auto g_inv_t = GATE_DATA[g.best_candidate_inverse_id].tableau<W>();
EXPECT_EQ(g_t_inv, g_inv_t) << g.name;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/stim/circuit/gate_data_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ struct GateVTable {
#ifndef NDEBUG
std::array<bool, NUM_DEFINED_GATES> seen{};
for (const auto &[gate_id, value] : gate_data_pairs) {
seen[gate_id] = true;
seen[(size_t)gate_id] = true;
}
for (const auto &gate : GATE_DATA.items) {
if (!seen[gate.id]) {
if (!seen[(size_t)gate.id]) {
throw std::invalid_argument(
"Missing gate data! A value was not defined for '" + std::string(gate.name) + "'.");
}
Expand Down
6 changes: 3 additions & 3 deletions src/stim/circuit/gate_decomposition.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ TEST(gate_decomposition, decompose_pair_instruction_into_segments_with_single_us
for (size_t k = 0; k < segment.targets.size(); k += 2) {
evens.push_back(segment.targets[k]);
}
out.safe_append(CircuitInstruction{stim::CX, {}, segment.targets});
out.safe_append(CircuitInstruction{stim::MX, segment.args, evens});
out.safe_append(CircuitInstruction{stim::CX, {}, segment.targets});
out.safe_append(CircuitInstruction{GateType::CX, {}, segment.targets});
out.safe_append(CircuitInstruction{GateType::MX, segment.args, evens});
out.safe_append(CircuitInstruction{GateType::CX, {}, segment.targets});
out.append_from_text("TICK");
};
decompose_pair_instruction_into_segments_with_single_use_controls(
Expand Down
2 changes: 1 addition & 1 deletion src/stim/cmd/command_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ int stim::command_convert(int argc, const char **argv) {
// convert arbitrary bits.
if (!details.include_measurements && !details.include_detectors && !details.include_observables) {
// dets outputs explicit value types, which we don't know if we get here.
if (out_format.id == SAMPLE_FORMAT_DETS) {
if (out_format.id == SampleFormat::SAMPLE_FORMAT_DETS) {
std::cerr
<< "\033[31mNot enough information given to parse input file to write to dets. Please given a circuit "
"with --types, a DEM file, or explicit number of each desired type\n";
Expand Down
2 changes: 1 addition & 1 deletion src/stim/cmd/command_detect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int stim::command_detect(int argc, const char **argv) {
find_argument("--shots", argc, argv) ? (uint64_t)find_int64_argument("--shots", 1, 0, INT64_MAX, argc, argv)
: find_argument("--detect", argc, argv) ? (uint64_t)find_int64_argument("--detect", 1, 0, INT64_MAX, argc, argv)
: 1;
if (out_format.id == SAMPLE_FORMAT_DETS && !append_observables) {
if (out_format.id == SampleFormat::SAMPLE_FORMAT_DETS && !append_observables) {
prepend_observables = true;
}

Expand Down
Loading
Loading