Skip to content

Commit

Permalink
Merge pull request #525 from quantumlib/qsim-pyopt
Browse files Browse the repository at this point in the history
Reduce isinstance calls
  • Loading branch information
95-martin-orion authored Apr 18, 2022
2 parents 9e75230 + 19e0594 commit 71d2a4f
Showing 1 changed file with 187 additions and 111 deletions.
298 changes: 187 additions & 111 deletions qsimcirq/qsim_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,103 +34,183 @@
]


def _cirq_gate_kind(gate: cirq.ops.Gate):
if isinstance(gate, cirq.ops.ControlledGate):
return _cirq_gate_kind(gate.sub_gate)
if isinstance(gate, cirq.ops.identity.IdentityGate):
# Identity gates will decompose to no-ops.
pass
if isinstance(gate, cirq.ops.XPowGate):
# cirq.rx also uses this path.
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kX
return qsim.kXPowGate
if isinstance(gate, cirq.ops.YPowGate):
# cirq.ry also uses this path.
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kY
return qsim.kYPowGate
if isinstance(gate, cirq.ops.ZPowGate):
# cirq.rz also uses this path.
if gate.global_shift == 0:
if gate.exponent == 1:
return qsim.kZ
if gate.exponent == 0.5:
return qsim.kS
if gate.exponent == 0.25:
return qsim.kT
return qsim.kZPowGate
if isinstance(gate, cirq.ops.HPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kH
return qsim.kHPowGate
if isinstance(gate, cirq.ops.CZPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCZ
return qsim.kCZPowGate
if isinstance(gate, cirq.ops.CXPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCX
return qsim.kCXPowGate
if isinstance(gate, cirq.ops.PhasedXPowGate):
return qsim.kPhasedXPowGate
if isinstance(gate, cirq.ops.PhasedXZGate):
return qsim.kPhasedXZGate
if isinstance(gate, cirq.ops.XXPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kXX
return qsim.kXXPowGate
if isinstance(gate, cirq.ops.YYPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kYY
return qsim.kYYPowGate
if isinstance(gate, cirq.ops.ZZPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kZZ
return qsim.kZZPowGate
if isinstance(gate, cirq.ops.SwapPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kSWAP
return qsim.kSwapPowGate
if isinstance(gate, cirq.ops.ISwapPowGate):
# cirq.riswap also uses this path.
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kISWAP
return qsim.kISwapPowGate
if isinstance(gate, cirq.ops.PhasedISwapPowGate):
# cirq.givens also uses this path.
return qsim.kPhasedISwapPowGate
if isinstance(gate, cirq.ops.FSimGate):
return qsim.kFSimGate
if isinstance(gate, cirq.ops.TwoQubitDiagonalGate):
return qsim.kTwoQubitDiagonalGate
if isinstance(gate, cirq.ops.ThreeQubitDiagonalGate):
return qsim.kThreeQubitDiagonalGate
if isinstance(gate, cirq.ops.CCZPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCCZ
return qsim.kCCZPowGate
if isinstance(gate, cirq.ops.CCXPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCCX
return qsim.kCCXPowGate
if isinstance(gate, cirq.ops.CSwapGate):
return qsim.kCSwapGate
if isinstance(gate, cirq.ops.MatrixGate):
if gate.num_qubits() <= 6:
return qsim.kMatrixGate
raise NotImplementedError(
f"Received matrix on {gate.num_qubits()} qubits; "
+ "only up to 6-qubit gates are supported."
)
if isinstance(gate, cirq.ops.MeasurementGate):
# needed to inherit SimulatesSamples in sims
return qsim.kMeasurement
def _translate_ControlledGate(gate: cirq.ControlledGate):
return _cirq_gate_kind(gate.sub_gate)


def _translate_XPowGate(gate: cirq.XPowGate):
# cirq.rx also uses this path.
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kX
return qsim.kXPowGate


def _translate_YPowGate(gate: cirq.YPowGate):
# cirq.ry also uses this path.
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kY
return qsim.kYPowGate


def _translate_ZPowGate(gate: cirq.ZPowGate):
# cirq.rz also uses this path.
if gate.global_shift == 0:
if gate.exponent == 1:
return qsim.kZ
if gate.exponent == 0.5:
return qsim.kS
if gate.exponent == 0.25:
return qsim.kT
return qsim.kZPowGate


def _translate_HPowGate(gate: cirq.HPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kH
return qsim.kHPowGate


def _translate_CZPowGate(gate: cirq.CZPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCZ
return qsim.kCZPowGate


def _translate_CXPowGate(gate: cirq.CXPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCX
return qsim.kCXPowGate


def _translate_PhasedXPowGate(gate: cirq.PhasedXPowGate):
return qsim.kPhasedXPowGate


def _translate_PhasedXZGate(gate: cirq.PhasedXZGate):
return qsim.kPhasedXZGate


def _translate_XXPowGate(gate: cirq.XXPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kXX
return qsim.kXXPowGate


def _translate_YYPowGate(gate: cirq.YYPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kYY
return qsim.kYYPowGate


def _translate_ZZPowGate(gate: cirq.ZZPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kZZ
return qsim.kZZPowGate


def _translate_SwapPowGate(gate: cirq.SwapPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kSWAP
return qsim.kSwapPowGate


def _translate_ISwapPowGate(gate: cirq.ISwapPowGate):
# cirq.riswap also uses this path.
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kISWAP
return qsim.kISwapPowGate


def _translate_PhasedISwapPowGate(gate: cirq.PhasedISwapPowGate):
# cirq.givens also uses this path.
return qsim.kPhasedISwapPowGate


def _translate_FSimGate(gate: cirq.FSimGate):
return qsim.kFSimGate


def _translate_TwoQubitDiagonalGate(gate: cirq.TwoQubitDiagonalGate):
return qsim.kTwoQubitDiagonalGate


def _translate_ThreeQubitDiagonalGate(gate: cirq.ThreeQubitDiagonalGate):
return qsim.kThreeQubitDiagonalGate


def _translate_CCZPowGate(gate: cirq.CCZPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCCZ
return qsim.kCCZPowGate


def _translate_CCXPowGate(gate: cirq.CCXPowGate):
if gate.exponent == 1 and gate.global_shift == 0:
return qsim.kCCX
return qsim.kCCXPowGate


def _translate_CSwapGate(gate: cirq.CSwapGate):
return qsim.kCSwapGate


def _translate_MatrixGate(gate: cirq.MatrixGate):
if gate.num_qubits() <= 6:
return qsim.kMatrixGate
raise NotImplementedError(
f"Received matrix on {gate.num_qubits()} qubits; "
+ "only up to 6-qubit gates are supported."
)


def _translate_MeasurementGate(gate: cirq.MeasurementGate):
# needed to inherit SimulatesSamples in sims
return qsim.kMeasurement


TYPE_TRANSLATOR = {
cirq.ControlledGate: _translate_ControlledGate,
cirq.XPowGate: _translate_XPowGate,
cirq.YPowGate: _translate_YPowGate,
cirq.ZPowGate: _translate_ZPowGate,
cirq.HPowGate: _translate_HPowGate,
cirq.CZPowGate: _translate_CZPowGate,
cirq.CXPowGate: _translate_CXPowGate,
cirq.PhasedXPowGate: _translate_PhasedXPowGate,
cirq.PhasedXZGate: _translate_PhasedXZGate,
cirq.XXPowGate: _translate_XXPowGate,
cirq.YYPowGate: _translate_YYPowGate,
cirq.ZZPowGate: _translate_ZZPowGate,
cirq.SwapPowGate: _translate_SwapPowGate,
cirq.ISwapPowGate: _translate_ISwapPowGate,
cirq.PhasedISwapPowGate: _translate_PhasedISwapPowGate,
cirq.FSimGate: _translate_FSimGate,
cirq.TwoQubitDiagonalGate: _translate_TwoQubitDiagonalGate,
cirq.ThreeQubitDiagonalGate: _translate_ThreeQubitDiagonalGate,
cirq.CCZPowGate: _translate_CCZPowGate,
cirq.CCXPowGate: _translate_CCXPowGate,
cirq.CSwapGate: _translate_CSwapGate,
cirq.MatrixGate: _translate_MatrixGate,
cirq.MeasurementGate: _translate_MeasurementGate,
}


def _cirq_gate_kind(gate: cirq.Gate):
for gate_type in type(gate).mro():
translator = TYPE_TRANSLATOR.get(gate_type, None)
if translator is not None:
return translator(gate)
# Unrecognized gates will be decomposed.
return None


def _control_details(gate: cirq.ops.ControlledGate, qubits):
def _has_cirq_gate_kind(op: cirq.Operation):
if isinstance(op, cirq.ControlledOperation):
return _has_cirq_gate_kind(op.sub_operation)
return any(t in TYPE_TRANSLATOR for t in type(op.gate).mro())


def _control_details(gate: cirq.ControlledGate, qubits):
control_qubits = []
control_values = []
# TODO: support qudit control
Expand Down Expand Up @@ -169,7 +249,7 @@ def add_op_to_opstring(
if len(qsim_op.qubits) != 1:
raise ValueError(f"OpString ops should have 1 qubit; got {len(qsim_op.qubits)}")

is_controlled = isinstance(qsim_gate, cirq.ops.ControlledGate)
is_controlled = isinstance(qsim_gate, cirq.ControlledGate)
if is_controlled:
raise ValueError(f"OpString ops should not be controlled.")

Expand All @@ -189,7 +269,7 @@ def add_op_to_circuit(
qubits = [qubit_to_index_dict[q] for q in qsim_op.qubits]

qsim_qubits = qubits
is_controlled = isinstance(qsim_gate, cirq.ops.ControlledGate)
is_controlled = isinstance(qsim_gate, cirq.ControlledGate)
if is_controlled:
control_qubits, control_values = _control_details(qsim_gate, qubits)
if control_qubits is None:
Expand Down Expand Up @@ -276,7 +356,7 @@ def _resolve_parameters_(
return QSimCircuit(cirq.resolve_parameters(super(), param_resolver, recursive))

def translate_cirq_to_qsim(
self, qubit_order: cirq.ops.QubitOrderOrList = cirq.ops.QubitOrder.DEFAULT
self, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT
) -> qsim.Circuit:
"""
Translates this Cirq circuit to the qsim representation.
Expand All @@ -286,31 +366,30 @@ def translate_cirq_to_qsim(
"""

qsim_circuit = qsim.Circuit()
ordered_qubits = cirq.ops.QubitOrder.as_qubit_order(qubit_order).order_for(
ordered_qubits = cirq.QubitOrder.as_qubit_order(qubit_order).order_for(
self.all_qubits()
)
qsim_circuit.num_qubits = len(ordered_qubits)

# qsim numbers qubits in reverse order from cirq
ordered_qubits = list(reversed(ordered_qubits))

def has_qsim_kind(op: cirq.ops.GateOperation):
return _cirq_gate_kind(op.gate) != None

def to_matrix(op: cirq.ops.GateOperation):
def to_matrix(op: cirq.GateOperation):
mat = cirq.unitary(op.gate, None)
if mat is None:
return NotImplemented

return cirq.ops.MatrixGate(mat).on(*op.qubits)
return cirq.MatrixGate(mat).on(*op.qubits)

qubit_to_index_dict = {q: i for i, q in enumerate(ordered_qubits)}
time_offset = 0
gate_count = 0
moment_indices = []
for moment in self:
ops_by_gate = [
cirq.decompose(op, fallback_decomposer=to_matrix, keep=has_qsim_kind)
cirq.decompose(
op, fallback_decomposer=to_matrix, keep=_has_cirq_gate_kind
)
for op in moment
]
moment_length = max((len(gate_ops) for gate_ops in ops_by_gate), default=0)
Expand All @@ -330,7 +409,7 @@ def to_matrix(op: cirq.ops.GateOperation):
return qsim_circuit, moment_indices

def translate_cirq_to_qtrajectory(
self, qubit_order: cirq.ops.QubitOrderOrList = cirq.ops.QubitOrder.DEFAULT
self, qubit_order: cirq.QubitOrderOrList = cirq.QubitOrder.DEFAULT
) -> qsim.NoisyCircuit:
"""
Translates this noisy Cirq circuit to the qsim representation.
Expand All @@ -339,7 +418,7 @@ def translate_cirq_to_qtrajectory(
gate indices)
"""
qsim_ncircuit = qsim.NoisyCircuit()
ordered_qubits = cirq.ops.QubitOrder.as_qubit_order(qubit_order).order_for(
ordered_qubits = cirq.QubitOrder.as_qubit_order(qubit_order).order_for(
self.all_qubits()
)

Expand All @@ -348,15 +427,12 @@ def translate_cirq_to_qtrajectory(

qsim_ncircuit.num_qubits = len(ordered_qubits)

def has_qsim_kind(op: cirq.ops.GateOperation):
return _cirq_gate_kind(op.gate) != None

def to_matrix(op: cirq.ops.GateOperation):
def to_matrix(op: cirq.GateOperation):
mat = cirq.unitary(op.gate, None)
if mat is None:
return NotImplemented

return cirq.ops.MatrixGate(mat).on(*op.qubits)
return cirq.MatrixGate(mat).on(*op.qubits)

qubit_to_index_dict = {q: i for i, q in enumerate(ordered_qubits)}
time_offset = 0
Expand All @@ -371,7 +447,7 @@ def to_matrix(op: cirq.ops.GateOperation):
for qsim_op in moment:
if cirq.has_unitary(qsim_op) or cirq.is_measurement(qsim_op):
oplist = cirq.decompose(
qsim_op, fallback_decomposer=to_matrix, keep=has_qsim_kind
qsim_op, fallback_decomposer=to_matrix, keep=_has_cirq_gate_kind
)
ops_by_gate.append(oplist)
moment_length = max(moment_length, len(oplist))
Expand Down

0 comments on commit 71d2a4f

Please sign in to comment.