diff --git a/qualtran/bloqs/data_loading/qrom_adjoint.py b/qualtran/bloqs/data_loading/qrom_adjoint.py index e07385d24..54e9230b8 100644 --- a/qualtran/bloqs/data_loading/qrom_adjoint.py +++ b/qualtran/bloqs/data_loading/qrom_adjoint.py @@ -18,14 +18,25 @@ @dataclasses.dataclass(frozen=True) class QROMAdjCondition(Condition): - dx: int - + key: cirq.MeasurementKey + dx: List[int] + + def replace_key(self, current: cirq.MeasurementKey, replacement: cirq.MeasurementKey): + return QROMAdjCondition(replacement, self.dx) if self.key == current else self + + def resolve(self, classical_data: cirq.ClassicalDataStoreReader) -> bool: + y = classical_data.get_digits(self.key) + active = False + for yi, dxi in zip(y, self.dx): + active = not active if yi * dxi == 1 else active + return active @attrs.define class QROMWithClassicalControls(QROM): QROM_bloq: QROM = field(default=None) + mz_key: str = field(default="target_mzs") def calc_dx(self, x): bitstring = [] @@ -33,7 +44,7 @@ def calc_dx(self, x): for i in range(len(self.QROM_bloq.target_bitsizes)): bitsize = self.QROM_bloq.target_bitsizes[i] data = self.QROM_bloq.data[i][x[x_start:x_start + bitsize]] - bitstring.append(iter_bits(data, bitsize)) + bitstring.extend(iter_bits(data, bitsize)) return bitstring @@ -54,6 +65,7 @@ def nth_operation( for i in range(len(target)): target_bits = iter_bits(i, N) dx = self.calc_dx(list(itertools.chain(selection_bits, target_bits))) + yield cirq.X(target[i]).with_classical_controls(QROMAdjCondition(cirq.MeasurementKey(self.mz_key), dx)) @attrs.frozen