diff --git a/qsimcirq/qsim_circuit.py b/qsimcirq/qsim_circuit.py index 21a2094e..fd7d9450 100644 --- a/qsimcirq/qsim_circuit.py +++ b/qsimcirq/qsim_circuit.py @@ -34,103 +34,183 @@ ] -def _cirq_gate_kind(gate: cirq.ops.Gate): - if isinstance(gate, cirq.ops.ControlledGate): - return _cirq_gate_kind(gate.sub_gate) - if isinstance(gate, cirq.ops.identity.IdentityGate): - # Identity gates will decompose to no-ops. - pass - if isinstance(gate, cirq.ops.XPowGate): - # cirq.rx also uses this path. - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kX - return qsim.kXPowGate - if isinstance(gate, cirq.ops.YPowGate): - # cirq.ry also uses this path. - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kY - return qsim.kYPowGate - if isinstance(gate, cirq.ops.ZPowGate): - # cirq.rz also uses this path. - if gate.global_shift == 0: - if gate.exponent == 1: - return qsim.kZ - if gate.exponent == 0.5: - return qsim.kS - if gate.exponent == 0.25: - return qsim.kT - return qsim.kZPowGate - if isinstance(gate, cirq.ops.HPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kH - return qsim.kHPowGate - if isinstance(gate, cirq.ops.CZPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kCZ - return qsim.kCZPowGate - if isinstance(gate, cirq.ops.CXPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kCX - return qsim.kCXPowGate - if isinstance(gate, cirq.ops.PhasedXPowGate): - return qsim.kPhasedXPowGate - if isinstance(gate, cirq.ops.PhasedXZGate): - return qsim.kPhasedXZGate - if isinstance(gate, cirq.ops.XXPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kXX - return qsim.kXXPowGate - if isinstance(gate, cirq.ops.YYPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kYY - return qsim.kYYPowGate - if isinstance(gate, cirq.ops.ZZPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kZZ - return qsim.kZZPowGate - if isinstance(gate, cirq.ops.SwapPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kSWAP - return qsim.kSwapPowGate - if isinstance(gate, cirq.ops.ISwapPowGate): - # cirq.riswap also uses this path. - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kISWAP - return qsim.kISwapPowGate - if isinstance(gate, cirq.ops.PhasedISwapPowGate): - # cirq.givens also uses this path. - return qsim.kPhasedISwapPowGate - if isinstance(gate, cirq.ops.FSimGate): - return qsim.kFSimGate - if isinstance(gate, cirq.ops.TwoQubitDiagonalGate): - return qsim.kTwoQubitDiagonalGate - if isinstance(gate, cirq.ops.ThreeQubitDiagonalGate): - return qsim.kThreeQubitDiagonalGate - if isinstance(gate, cirq.ops.CCZPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kCCZ - return qsim.kCCZPowGate - if isinstance(gate, cirq.ops.CCXPowGate): - if gate.exponent == 1 and gate.global_shift == 0: - return qsim.kCCX - return qsim.kCCXPowGate - if isinstance(gate, cirq.ops.CSwapGate): - return qsim.kCSwapGate - if isinstance(gate, cirq.ops.MatrixGate): - if gate.num_qubits() <= 6: - return qsim.kMatrixGate - raise NotImplementedError( - f"Received matrix on {gate.num_qubits()} qubits; " - + "only up to 6-qubit gates are supported." - ) - if isinstance(gate, cirq.ops.MeasurementGate): - # needed to inherit SimulatesSamples in sims - return qsim.kMeasurement +def _translate_ControlledGate(gate: cirq.ControlledGate): + return _cirq_gate_kind(gate.sub_gate) + + +def _translate_XPowGate(gate: cirq.XPowGate): + # cirq.rx also uses this path. + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kX + return qsim.kXPowGate + + +def _translate_YPowGate(gate: cirq.YPowGate): + # cirq.ry also uses this path. + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kY + return qsim.kYPowGate + + +def _translate_ZPowGate(gate: cirq.ZPowGate): + # cirq.rz also uses this path. + if gate.global_shift == 0: + if gate.exponent == 1: + return qsim.kZ + if gate.exponent == 0.5: + return qsim.kS + if gate.exponent == 0.25: + return qsim.kT + return qsim.kZPowGate + + +def _translate_HPowGate(gate: cirq.HPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kH + return qsim.kHPowGate + + +def _translate_CZPowGate(gate: cirq.CZPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kCZ + return qsim.kCZPowGate + + +def _translate_CXPowGate(gate: cirq.CXPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kCX + return qsim.kCXPowGate + + +def _translate_PhasedXPowGate(gate: cirq.PhasedXPowGate): + return qsim.kPhasedXPowGate + + +def _translate_PhasedXZGate(gate: cirq.PhasedXZGate): + return qsim.kPhasedXZGate + + +def _translate_XXPowGate(gate: cirq.XXPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kXX + return qsim.kXXPowGate + + +def _translate_YYPowGate(gate: cirq.YYPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kYY + return qsim.kYYPowGate + + +def _translate_ZZPowGate(gate: cirq.ZZPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kZZ + return qsim.kZZPowGate + + +def _translate_SwapPowGate(gate: cirq.SwapPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kSWAP + return qsim.kSwapPowGate + + +def _translate_ISwapPowGate(gate: cirq.ISwapPowGate): + # cirq.riswap also uses this path. + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kISWAP + return qsim.kISwapPowGate + + +def _translate_PhasedISwapPowGate(gate: cirq.PhasedISwapPowGate): + # cirq.givens also uses this path. + return qsim.kPhasedISwapPowGate + + +def _translate_FSimGate(gate: cirq.FSimGate): + return qsim.kFSimGate + + +def _translate_TwoQubitDiagonalGate(gate: cirq.TwoQubitDiagonalGate): + return qsim.kTwoQubitDiagonalGate + + +def _translate_ThreeQubitDiagonalGate(gate: cirq.ThreeQubitDiagonalGate): + return qsim.kThreeQubitDiagonalGate + + +def _translate_CCZPowGate(gate: cirq.CCZPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kCCZ + return qsim.kCCZPowGate + + +def _translate_CCXPowGate(gate: cirq.CCXPowGate): + if gate.exponent == 1 and gate.global_shift == 0: + return qsim.kCCX + return qsim.kCCXPowGate + + +def _translate_CSwapGate(gate: cirq.CSwapGate): + return qsim.kCSwapGate + + +def _translate_MatrixGate(gate: cirq.MatrixGate): + if gate.num_qubits() <= 6: + return qsim.kMatrixGate + raise NotImplementedError( + f"Received matrix on {gate.num_qubits()} qubits; " + + "only up to 6-qubit gates are supported." + ) + + +def _translate_MeasurementGate(gate: cirq.MeasurementGate): + # needed to inherit SimulatesSamples in sims + return qsim.kMeasurement + + +TYPE_TRANSLATOR = { + cirq.ControlledGate: _translate_ControlledGate, + cirq.XPowGate: _translate_XPowGate, + cirq.YPowGate: _translate_YPowGate, + cirq.ZPowGate: _translate_ZPowGate, + cirq.HPowGate: _translate_HPowGate, + cirq.CZPowGate: _translate_CZPowGate, + cirq.CXPowGate: _translate_CXPowGate, + cirq.PhasedXPowGate: _translate_PhasedXPowGate, + cirq.PhasedXZGate: _translate_PhasedXZGate, + cirq.XXPowGate: _translate_XXPowGate, + cirq.YYPowGate: _translate_YYPowGate, + cirq.ZZPowGate: _translate_ZZPowGate, + cirq.SwapPowGate: _translate_SwapPowGate, + cirq.ISwapPowGate: _translate_ISwapPowGate, + cirq.PhasedISwapPowGate: _translate_PhasedISwapPowGate, + cirq.FSimGate: _translate_FSimGate, + cirq.TwoQubitDiagonalGate: _translate_TwoQubitDiagonalGate, + cirq.ThreeQubitDiagonalGate: _translate_ThreeQubitDiagonalGate, + cirq.CCZPowGate: _translate_CCZPowGate, + cirq.CCXPowGate: _translate_CCXPowGate, + cirq.CSwapGate: _translate_CSwapGate, + cirq.MatrixGate: _translate_MatrixGate, + cirq.MeasurementGate: _translate_MeasurementGate, +} + + +def _cirq_gate_kind(gate: cirq.Gate): + for gate_type in type(gate).mro(): + translator = TYPE_TRANSLATOR.get(gate_type, None) + if translator is not None: + return translator(gate) # Unrecognized gates will be decomposed. return None -def _control_details(gate: cirq.ops.ControlledGate, qubits): +def _has_cirq_gate_kind(op: cirq.Operation): + if isinstance(op, cirq.ControlledOperation): + return _has_cirq_gate_kind(op.sub_operation) + return any(t in TYPE_TRANSLATOR for t in type(op.gate).mro()) + + +def _control_details(gate: cirq.ControlledGate, qubits): control_qubits = [] control_values = [] # TODO: support qudit control @@ -169,7 +249,7 @@ def add_op_to_opstring( if len(qsim_op.qubits) != 1: raise ValueError(f"OpString ops should have 1 qubit; got {len(qsim_op.qubits)}") - is_controlled = isinstance(qsim_gate, cirq.ops.ControlledGate) + is_controlled = isinstance(qsim_gate, cirq.ControlledGate) if is_controlled: raise ValueError(f"OpString ops should not be controlled.") @@ -189,7 +269,7 @@ def add_op_to_circuit( qubits = [qubit_to_index_dict[q] for q in qsim_op.qubits] qsim_qubits = qubits - is_controlled = isinstance(qsim_gate, cirq.ops.ControlledGate) + is_controlled = isinstance(qsim_gate, cirq.ControlledGate) if is_controlled: control_qubits, control_values = _control_details(qsim_gate, qubits) if control_qubits is None: @@ -276,7 +356,7 @@ def _resolve_parameters_( return QSimCircuit(cirq.resolve_parameters(super(), param_resolver, recursive)) def translate_cirq_to_qsim( - self, qubit_order: cirq.ops.QubitOrderOrList = cirq.ops.QubitOrder.DEFAULT + self, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT ) -> qsim.Circuit: """ Translates this Cirq circuit to the qsim representation. @@ -286,7 +366,7 @@ def translate_cirq_to_qsim( """ qsim_circuit = qsim.Circuit() - ordered_qubits = cirq.ops.QubitOrder.as_qubit_order(qubit_order).order_for( + ordered_qubits = cirq.QubitOrder.as_qubit_order(qubit_order).order_for( self.all_qubits() ) qsim_circuit.num_qubits = len(ordered_qubits) @@ -294,15 +374,12 @@ def translate_cirq_to_qsim( # qsim numbers qubits in reverse order from cirq ordered_qubits = list(reversed(ordered_qubits)) - def has_qsim_kind(op: cirq.ops.GateOperation): - return _cirq_gate_kind(op.gate) != None - - def to_matrix(op: cirq.ops.GateOperation): + def to_matrix(op: cirq.GateOperation): mat = cirq.unitary(op.gate, None) if mat is None: return NotImplemented - return cirq.ops.MatrixGate(mat).on(*op.qubits) + return cirq.MatrixGate(mat).on(*op.qubits) qubit_to_index_dict = {q: i for i, q in enumerate(ordered_qubits)} time_offset = 0 @@ -310,7 +387,9 @@ def to_matrix(op: cirq.ops.GateOperation): moment_indices = [] for moment in self: ops_by_gate = [ - cirq.decompose(op, fallback_decomposer=to_matrix, keep=has_qsim_kind) + cirq.decompose( + op, fallback_decomposer=to_matrix, keep=_has_cirq_gate_kind + ) for op in moment ] moment_length = max((len(gate_ops) for gate_ops in ops_by_gate), default=0) @@ -330,7 +409,7 @@ def to_matrix(op: cirq.ops.GateOperation): return qsim_circuit, moment_indices def translate_cirq_to_qtrajectory( - self, qubit_order: cirq.ops.QubitOrderOrList = cirq.ops.QubitOrder.DEFAULT + self, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT ) -> qsim.NoisyCircuit: """ Translates this noisy Cirq circuit to the qsim representation. @@ -339,7 +418,7 @@ def translate_cirq_to_qtrajectory( gate indices) """ qsim_ncircuit = qsim.NoisyCircuit() - ordered_qubits = cirq.ops.QubitOrder.as_qubit_order(qubit_order).order_for( + ordered_qubits = cirq.QubitOrder.as_qubit_order(qubit_order).order_for( self.all_qubits() ) @@ -348,15 +427,12 @@ def translate_cirq_to_qtrajectory( qsim_ncircuit.num_qubits = len(ordered_qubits) - def has_qsim_kind(op: cirq.ops.GateOperation): - return _cirq_gate_kind(op.gate) != None - - def to_matrix(op: cirq.ops.GateOperation): + def to_matrix(op: cirq.GateOperation): mat = cirq.unitary(op.gate, None) if mat is None: return NotImplemented - return cirq.ops.MatrixGate(mat).on(*op.qubits) + return cirq.MatrixGate(mat).on(*op.qubits) qubit_to_index_dict = {q: i for i, q in enumerate(ordered_qubits)} time_offset = 0 @@ -371,7 +447,7 @@ def to_matrix(op: cirq.ops.GateOperation): for qsim_op in moment: if cirq.has_unitary(qsim_op) or cirq.is_measurement(qsim_op): oplist = cirq.decompose( - qsim_op, fallback_decomposer=to_matrix, keep=has_qsim_kind + qsim_op, fallback_decomposer=to_matrix, keep=_has_cirq_gate_kind ) ops_by_gate.append(oplist) moment_length = max(moment_length, len(oplist))