From 84d433742c35b0ea86e57909be6af657ea65ddfa Mon Sep 17 00:00:00 2001 From: Dan Mills <52407433+daniel-mills-cqc@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:13:58 +0000 Subject: [PATCH] Use CircuitShots --- qermit/spam/full_spam_correction.py | 12 +++++++----- tests/full_spam_test.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/qermit/spam/full_spam_correction.py b/qermit/spam/full_spam_correction.py index f567705d..e2969b94 100644 --- a/qermit/spam/full_spam_correction.py +++ b/qermit/spam/full_spam_correction.py @@ -59,6 +59,8 @@ def task( == qubit_subsets ): return (wire, [], []) + + print("qubit_subsets: ", qubit_subsets) process_circuit = Circuit( len([qb for subset in qubit_subsets for qb in subset]) @@ -179,7 +181,7 @@ def get_mid_circuit_measure_map( def gen_get_bit_maps_task() -> MitTask: """ - Returns a task that takes a list of circuits and returns the circuits, and a map betwen + Returns a task that takes a list of circuits and returns the circuits, and a map between each circuit bit and the qubit it is measured on. """ @@ -193,12 +195,12 @@ def task( """ bq_maps = [] for c in circuit_shots: - qb_map = c[0].qubit_to_bit_map - # if condition met, implies that mid circuit measurement has ocurred and not accounted for + qb_map = c.Circuit.qubit_to_bit_map + # if condition met, implies that mid circuit measurement has occurred and not accounted for # in this case, iterate through circuit commands to get Qubits for all Bits - if len(qb_map) != len(c[0].bits): + if len(qb_map) != len(c.Circuit.bits): bq_maps.append( - (qb_map, get_mid_circuit_measure_map(c[0], set(qb_map.values()))) + (qb_map, get_mid_circuit_measure_map(c.Circuit, set(qb_map.values()))) ) else: # else, just invert map for later correction diff --git a/tests/full_spam_test.py b/tests/full_spam_test.py index ac28afd6..9719324b 100644 --- a/tests/full_spam_test.py +++ b/tests/full_spam_test.py @@ -17,6 +17,8 @@ from pytket import Bit, Circuit, Qubit from pytket.extensions.qiskit import AerBackend # type: ignore +from qermit import CircuitShots + from qermit.spam import ( # type: ignore CorrectionMethod, ) @@ -90,7 +92,7 @@ def test_gen_get_bit_maps_task(): c0 = Circuit(3).CX(0, 1).X(2).measure_all() c1 = Circuit(2, 2).X(0).Measure(0, 0).X(1).SWAP(0, 1).Measure(0, 1) - wire = [(c0, 10), (c1, 50)] + wire = [CircuitShots(c0, 10), CircuitShots(c1, 50)] res = task([wire]) assert len(res) == 2 assert res[0] == wire