From ec74f20a9072ddeb7f31175de30b3cf7fcdeca99 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 27 May 2022 16:11:15 -0700 Subject: [PATCH 1/7] Add support for repeated keys in QSimSimulator --- qsimcirq/qsim_simulator.py | 33 +++++++++++++++++++-------------- qsimcirq_tests/qsimcirq_test.py | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index ee34d67f..3353f959 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -323,8 +323,8 @@ def _sample_measure_results( cirq.MeasurementGate ) ] - measured_qubits: List[cirq.Qid] = [] - bounds: Dict[str, Tuple] = {} + num_qubits_by_key: Dict[str, int] = {} + bounds: Dict[str, Tuple[int, int]] = {} meas_ops: Dict[str, List[cirq.GateOperation]] = {} current_index = 0 for op in measurement_ops: @@ -332,21 +332,26 @@ def _sample_measure_results( key = cirq.measurement_key_name(gate) meas_ops.setdefault(key, []) meas_ops[key].append(op) - if key in bounds: - raise ValueError(f"Duplicate MeasurementGate with key {key}") - bounds[key] = (current_index, current_index + len(op.qubits)) - measured_qubits.extend(op.qubits) - current_index += len(op.qubits) + n = len(op.qubits) + if key in num_qubits_by_key: + if n != num_qubits_by_key[key]: + raise ValueError( + f'repeated key {key!r} with different numbers of qubits: ' + f'{num_qubits_by_key[key]} != {n}' + ) + else: + num_qubits_by_key[key] = n + if key not in bounds: + bounds[key] = (current_index, current_index + n) + current_index += n # Set qsim options - options = {} - options.update(self.qsim_options) + options = {**self.qsim_options} - results = {} - for key, bound in bounds.items(): - results[key] = np.ndarray( - shape=(repetitions, len(meas_ops[key]), bound[1] - bound[0]), dtype=int - ) + results = { + key: np.ndarray(shape=(repetitions, len(meas_ops[key]), n), dtype=int) + for key, n in num_qubits_by_key.items() + } noisy = _needs_trajectories(program) if not noisy and program.are_all_measurements_terminal() and repetitions > 1: diff --git a/qsimcirq_tests/qsimcirq_test.py b/qsimcirq_tests/qsimcirq_test.py index 6e516676..e0d84bad 100644 --- a/qsimcirq_tests/qsimcirq_test.py +++ b/qsimcirq_tests/qsimcirq_test.py @@ -57,6 +57,22 @@ def test_empty_moment(mode: str): assert result.final_state_vector.shape == (4,) +def test_repeated_keys(): + q = cirq.LineQubit(0) + circuit = cirq.Circuit( + cirq.measure(q, key='m'), + cirq.X(q), + cirq.measure(q, key='m'), + cirq.X(q), + cirq.measure(q, key='m'), + ) + result = qsimcirq.QSimSimulator().simulate(circuit, repetitions=10) + assert result.records['m'] == (10, 3, 1) + assert np.all(result.records['m'][:, 0, :] == 0) + assert np.all(result.records['m'][:, 1, :] == 1) + assert np.all(result.records['m'][:, 2, :] == 0) + + def test_cirq_too_big_gate(): # Pick qubits. a, b, c, d, e, f, g = [ From e3068c49a80506c6fd72c00e74e148e7868a5e3b Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 27 May 2022 16:23:48 -0700 Subject: [PATCH 2/7] Format --- qsimcirq/qsim_simulator.py | 4 ++-- qsimcirq_tests/qsimcirq_test.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index 3353f959..d6b3c1e5 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -336,8 +336,8 @@ def _sample_measure_results( if key in num_qubits_by_key: if n != num_qubits_by_key[key]: raise ValueError( - f'repeated key {key!r} with different numbers of qubits: ' - f'{num_qubits_by_key[key]} != {n}' + f"repeated key {key!r} with different numbers of qubits: " + f"{num_qubits_by_key[key]} != {n}" ) else: num_qubits_by_key[key] = n diff --git a/qsimcirq_tests/qsimcirq_test.py b/qsimcirq_tests/qsimcirq_test.py index e0d84bad..5320865b 100644 --- a/qsimcirq_tests/qsimcirq_test.py +++ b/qsimcirq_tests/qsimcirq_test.py @@ -60,17 +60,17 @@ def test_empty_moment(mode: str): def test_repeated_keys(): q = cirq.LineQubit(0) circuit = cirq.Circuit( - cirq.measure(q, key='m'), + cirq.measure(q, key="m"), cirq.X(q), - cirq.measure(q, key='m'), + cirq.measure(q, key="m"), cirq.X(q), - cirq.measure(q, key='m'), + cirq.measure(q, key="m"), ) result = qsimcirq.QSimSimulator().simulate(circuit, repetitions=10) - assert result.records['m'] == (10, 3, 1) - assert np.all(result.records['m'][:, 0, :] == 0) - assert np.all(result.records['m'][:, 1, :] == 1) - assert np.all(result.records['m'][:, 2, :] == 0) + assert result.records["m"] == (10, 3, 1) + assert np.all(result.records["m"][:, 0, :] == 0) + assert np.all(result.records["m"][:, 1, :] == 1) + assert np.all(result.records["m"][:, 2, :] == 0) def test_cirq_too_big_gate(): From 6d0397993ec36b77519166d46d2c32dc12705b9b Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Fri, 27 May 2022 16:25:43 -0700 Subject: [PATCH 3/7] s/simulate/run/ --- qsimcirq_tests/qsimcirq_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qsimcirq_tests/qsimcirq_test.py b/qsimcirq_tests/qsimcirq_test.py index 5320865b..82651df7 100644 --- a/qsimcirq_tests/qsimcirq_test.py +++ b/qsimcirq_tests/qsimcirq_test.py @@ -66,7 +66,7 @@ def test_repeated_keys(): cirq.X(q), cirq.measure(q, key="m"), ) - result = qsimcirq.QSimSimulator().simulate(circuit, repetitions=10) + result = qsimcirq.QSimSimulator().run(circuit, repetitions=10) assert result.records["m"] == (10, 3, 1) assert np.all(result.records["m"][:, 0, :] == 0) assert np.all(result.records["m"][:, 1, :] == 1) From 99d4ce9e8e19749e6c19214557ecde3d9aa26b9f Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 31 May 2022 09:26:30 -0700 Subject: [PATCH 4/7] Store bounds for each instance of a measurement key --- qsimcirq/qsim_simulator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index d6b3c1e5..06df3f8f 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -324,7 +324,7 @@ def _sample_measure_results( ) ] num_qubits_by_key: Dict[str, int] = {} - bounds: Dict[str, Tuple[int, int]] = {} + bounds: List[Tuple[str, int, int]] = [] meas_ops: Dict[str, List[cirq.GateOperation]] = {} current_index = 0 for op in measurement_ops: @@ -341,8 +341,7 @@ def _sample_measure_results( ) else: num_qubits_by_key[key] = n - if key not in bounds: - bounds[key] = (current_index, current_index + n) + bounds.append((key, current_index, current_index + n)) current_index += n # Set qsim options @@ -414,7 +413,7 @@ def _sample_measure_results( options["s"] = self.get_seed() measurements[i] = sampler_fn(options) - for key, (start, end) in bounds.items(): + for key, start, end in bounds: for i, op in enumerate(meas_ops[key]): invert_mask = op.gate.full_invert_mask() results[key][:, i, :] = measurements[:, start:end] ^ invert_mask From 4b828c38ab78c02fc889a16a0f1f64a7d3db06bd Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 31 May 2022 10:03:32 -0700 Subject: [PATCH 5/7] Fix bounds calculation --- qsimcirq/qsim_simulator.py | 38 +++++++++++++-------------------- qsimcirq_tests/qsimcirq_test.py | 2 +- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index 06df3f8f..78a85e92 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -313,10 +313,13 @@ def _sample_measure_results( qubit_map = {qubit: index for index, qubit in enumerate(ordered_qubits)} - # Computes - # - the list of qubits to be measured - # - the start (inclusive) and end (exclusive) indices of each measurement - # - a mapping from measurement key to measurement gate + # Compute: + # - number of qubits for each measurement key. + # - measurement ops for each measurement key. + # - info about each measurement, including the key, instance (for repeated keys), + # start (inclusive) and end (exclusive) indices in the qsim output, and + # invert mask. + # - total number of measured bits. measurement_ops = [ op for _, op, _ in program.findall_operations_with_gate_type( @@ -324,13 +327,14 @@ def _sample_measure_results( ) ] num_qubits_by_key: Dict[str, int] = {} - bounds: List[Tuple[str, int, int]] = [] meas_ops: Dict[str, List[cirq.GateOperation]] = {} - current_index = 0 + info: List[Tuple[str, int, Tuple[bool, ...], int, int]] = [] + num_bits = 0 for op in measurement_ops: gate = op.gate key = cirq.measurement_key_name(gate) meas_ops.setdefault(key, []) + i = len(meas_ops[key]) meas_ops[key].append(op) n = len(op.qubits) if key in num_qubits_by_key: @@ -341,8 +345,8 @@ def _sample_measure_results( ) else: num_qubits_by_key[key] = n - bounds.append((key, current_index, current_index + n)) - current_index += n + info.append((key, i, gate.full_invert_mask(), num_bits, num_bits + n)) + num_bits += n # Set qsim options options = {**self.qsim_options} @@ -398,25 +402,13 @@ def _sample_measure_results( translator_fn_name, cirq.QubitOrder.DEFAULT, ) - measurements = np.empty( - shape=( - repetitions, - sum( - cirq.num_qubits(op) - for oplist in meas_ops.values() - for op in oplist - ), - ), - dtype=int, - ) + measurements = np.empty(shape=(repetitions, num_bits), dtype=int) for i in range(repetitions): options["s"] = self.get_seed() measurements[i] = sampler_fn(options) - for key, start, end in bounds: - for i, op in enumerate(meas_ops[key]): - invert_mask = op.gate.full_invert_mask() - results[key][:, i, :] = measurements[:, start:end] ^ invert_mask + for key, i, invert_mask, start, end in info: + results[key][:, i, :] = measurements[:, start:end] ^ invert_mask return results diff --git a/qsimcirq_tests/qsimcirq_test.py b/qsimcirq_tests/qsimcirq_test.py index 82651df7..ec8c4a23 100644 --- a/qsimcirq_tests/qsimcirq_test.py +++ b/qsimcirq_tests/qsimcirq_test.py @@ -67,7 +67,7 @@ def test_repeated_keys(): cirq.measure(q, key="m"), ) result = qsimcirq.QSimSimulator().run(circuit, repetitions=10) - assert result.records["m"] == (10, 3, 1) + assert result.records["m"].shape == (10, 3, 1) assert np.all(result.records["m"][:, 0, :] == 0) assert np.all(result.records["m"][:, 1, :] == 1) assert np.all(result.records["m"][:, 2, :] == 0) From 8cdd0bfbc8915f41595150b210cdb489265d299d Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 31 May 2022 11:21:15 -0700 Subject: [PATCH 6/7] Remove unused imports --- qsimcirq/qsim_simulator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index 78a85e92..b15bee0a 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -14,8 +14,7 @@ from collections import deque from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union -from xml.etree.ElementPath import ops +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import cirq From ece373cb718bd7bae04607031a298c299d73f519 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Thu, 2 Jun 2022 14:38:54 -0700 Subject: [PATCH 7/7] Fixes from review --- qsimcirq/qsim_simulator.py | 41 +++++++++++++++++++++++++------ qsimcirq_tests/qsimcirq_test.py | 43 +++++++++++++++++++++++++++------ 2 files changed, 69 insertions(+), 15 deletions(-) diff --git a/qsimcirq/qsim_simulator.py b/qsimcirq/qsim_simulator.py index b15bee0a..5835c574 100644 --- a/qsimcirq/qsim_simulator.py +++ b/qsimcirq/qsim_simulator.py @@ -167,6 +167,25 @@ def as_dict(self): } +@dataclass +class MeasInfo: + """Info about each measure operation in the circuit being simulated. + + Attributes: + key: The measurement key. + idx: The "instance" of a possibly-repeated measurement key. + invert_mask: True for any measurement bits that should be inverted. + start: Start index in qsim's output array for this measurement. + end: End index (non-inclusive) in qsim's output array. + """ + + key: str + idx: int + invert_mask: Tuple[bool, ...] + start: int + end: int + + class QSimSimulator( cirq.SimulatesSamples, cirq.SimulatesAmplitudes, @@ -315,9 +334,7 @@ def _sample_measure_results( # Compute: # - number of qubits for each measurement key. # - measurement ops for each measurement key. - # - info about each measurement, including the key, instance (for repeated keys), - # start (inclusive) and end (exclusive) indices in the qsim output, and - # invert mask. + # - measurement info for each measurement. # - total number of measured bits. measurement_ops = [ op @@ -327,7 +344,7 @@ def _sample_measure_results( ] num_qubits_by_key: Dict[str, int] = {} meas_ops: Dict[str, List[cirq.GateOperation]] = {} - info: List[Tuple[str, int, Tuple[bool, ...], int, int]] = [] + meas_infos: List[MeasInfo] = [] num_bits = 0 for op in measurement_ops: gate = op.gate @@ -344,7 +361,15 @@ def _sample_measure_results( ) else: num_qubits_by_key[key] = n - info.append((key, i, gate.full_invert_mask(), num_bits, num_bits + n)) + meas_infos.append( + MeasInfo( + key=key, + idx=i, + invert_mask=gate.full_invert_mask(), + start=num_bits, + end=num_bits + n, + ) + ) num_bits += n # Set qsim options @@ -406,8 +431,10 @@ def _sample_measure_results( options["s"] = self.get_seed() measurements[i] = sampler_fn(options) - for key, i, invert_mask, start, end in info: - results[key][:, i, :] = measurements[:, start:end] ^ invert_mask + for m in meas_infos: + results[m.key][:, m.idx, :] = ( + measurements[:, m.start : m.end] ^ m.invert_mask + ) return results diff --git a/qsimcirq_tests/qsimcirq_test.py b/qsimcirq_tests/qsimcirq_test.py index ec8c4a23..ce439f29 100644 --- a/qsimcirq_tests/qsimcirq_test.py +++ b/qsimcirq_tests/qsimcirq_test.py @@ -58,19 +58,46 @@ def test_empty_moment(mode: str): def test_repeated_keys(): - q = cirq.LineQubit(0) + q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( - cirq.measure(q, key="m"), - cirq.X(q), - cirq.measure(q, key="m"), - cirq.X(q), - cirq.measure(q, key="m"), + cirq.Moment(cirq.measure(q0, key="m")), + cirq.Moment(cirq.X(q1)), + cirq.Moment(cirq.measure(q1, key="m")), + cirq.Moment(cirq.X(q0)), + cirq.Moment(cirq.measure(q0, key="m")), + cirq.Moment(cirq.X(q1)), + cirq.Moment(cirq.measure(q1, key="m")), ) result = qsimcirq.QSimSimulator().run(circuit, repetitions=10) - assert result.records["m"].shape == (10, 3, 1) + assert result.records["m"].shape == (10, 4, 1) assert np.all(result.records["m"][:, 0, :] == 0) assert np.all(result.records["m"][:, 1, :] == 1) - assert np.all(result.records["m"][:, 2, :] == 0) + assert np.all(result.records["m"][:, 2, :] == 1) + assert np.all(result.records["m"][:, 3, :] == 0) + + +def test_repeated_keys_same_moment(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.Moment(cirq.X(q1)), + cirq.Moment(cirq.measure(q0, key="m"), cirq.measure(q1, key="m")), + ) + result = qsimcirq.QSimSimulator().run(circuit, repetitions=10) + assert result.records["m"].shape == (10, 2, 1) + assert np.all(result.records["m"][:, 0, :] == 0) + assert np.all(result.records["m"][:, 1, :] == 1) + + +def test_repeated_keys_different_numbers_of_qubits(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key="m"), + cirq.measure(q0, q1, key="m"), + ) + with pytest.raises( + ValueError, match="repeated key 'm' with different numbers of qubits" + ): + _ = qsimcirq.QSimSimulator().run(circuit, repetitions=10) def test_cirq_too_big_gate():