diff --git a/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.cpp b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.cpp index 687341d36..b8a92d15e 100644 --- a/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.cpp +++ b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.cpp @@ -15,6 +15,7 @@ #include "xacc_service.hpp" #include "IonTrapTwoQubitPass.hpp" #include "Accelerator.hpp" +#include "IonTrapTwoQubitPassVisitor.hpp" namespace xacc { namespace quantum { @@ -23,6 +24,10 @@ namespace quantum { // Two-qubit decompositions // +IonTrapTwoQubitPass::IonTrapTwoQubitPass() +{ +} + std::pair IonTrapTwoQubitPass::findMSPhases(IonTrapMSPhaseMap *msPhases, InstPtr cnot) { std::size_t leftIdx = std::min(cnot->bits()[0], cnot->bits()[1]); std::size_t rightIdx = std::max(cnot->bits()[0], cnot->bits()[1]); @@ -53,10 +58,15 @@ void IonTrapTwoQubitPass::apply(std::shared_ptr program, logTransCallback = options.get("log-trans-cb"); } - auto gateRegistry = xacc::getService("quantum"); + auto _twoQubitPassVisitor = std::make_shared(); iontrapFlattenComposite(program); + HeterogeneousMap paramsMap{std::make_pair("composite", program), + std::make_pair("options", options)}; + + _twoQubitPassVisitor->_paramsMap = paramsMap; + for (std::size_t instIdx = 0; instIdx < program->nInstructions();) { InstPtr inst = program->getInstruction(instIdx); if (!inst->isEnabled()) { @@ -64,123 +74,19 @@ void IonTrapTwoQubitPass::apply(std::shared_ptr program, continue; } - if (inst->name() == "CNOT") { - auto [controlMSPhase, targetMSPhase] = findMSPhases(msPhases, inst); - InstPtr ry1 = gateRegistry->createInstruction("Ry", {inst->bits()[0]}, {-M_PI/2.0}); - InstPtr xx = gateRegistry->createInstruction("XX", inst->bits(), {M_PI/4.0}); - InstPtr ry2 = gateRegistry->createInstruction("Ry", {inst->bits()[0]}, {M_PI/2.0}); - InstPtr rz = gateRegistry->createInstruction("Rz", {inst->bits()[0]}, {M_PI/2.0}); - InstPtr rx = gateRegistry->createInstruction("Rx", {inst->bits()[1]}, {M_PI/2.0}); - - std::size_t i = instIdx; - program->insertInstruction(i++, ry1); - // TODO: Note that this is kind of incorrect: really, the combination of these Rz gates - // and an MS gate is actually an XX gate (see https://doi.org/10.1088/1367-2630/18/2/023048) - // but we are surrounding an XX instruction with Rz instructions. But this will - // work for now - if (controlMSPhase) { - InstPtr msRz1 = gateRegistry->createInstruction("Rz", {inst->bits()[0]}, {controlMSPhase}); - program->insertInstruction(i++, msRz1); - } - if (targetMSPhase) { - InstPtr msRz2 = gateRegistry->createInstruction("Rz", {inst->bits()[1]}, {targetMSPhase}); - program->insertInstruction(i++, msRz2); - } - program->insertInstruction(i++, xx); - if (controlMSPhase) { - InstPtr msRz3 = gateRegistry->createInstruction("Rz", {inst->bits()[0]}, {-controlMSPhase}); - program->insertInstruction(i++, msRz3); - } - if (targetMSPhase) { - InstPtr msRz4 = gateRegistry->createInstruction("Rz", {inst->bits()[1]}, {-targetMSPhase}); - program->insertInstruction(i++, msRz4); - } - program->insertInstruction(i++, ry2); - program->insertInstruction(i++, rz); - program->insertInstruction(i++, rx); - - if (logTransCallback) { - std::vector newInsts; - for (std::size_t j = instIdx; j < i; j++) { - newInsts.push_back(program->getInstruction(j)); - } - logTransCallback({inst}, newInsts); - } - } else if (inst->name() == "CH") { - InstPtr s = gateRegistry->createInstruction("S", {inst->bits()[1]}); - InstPtr h = gateRegistry->createInstruction("H", {inst->bits()[1]}); - InstPtr t = gateRegistry->createInstruction("T", {inst->bits()[1]}); - InstPtr cx = gateRegistry->createInstruction("CNOT", inst->bits()); - InstPtr tdg = gateRegistry->createInstruction("Tdg", {inst->bits()[1]}); - InstPtr h2 = gateRegistry->createInstruction("H", {inst->bits()[1]}); - InstPtr sdg = gateRegistry->createInstruction("Sdg", {inst->bits()[1]}); - - program->insertInstruction(instIdx, s); - program->insertInstruction(instIdx+1, h); - program->insertInstruction(instIdx+2, t); - program->insertInstruction(instIdx+3, cx); - program->insertInstruction(instIdx+4, tdg); - program->insertInstruction(instIdx+5, h2); - program->insertInstruction(instIdx+6, sdg); - - if (logTransCallback) { - logTransCallback({program->getInstruction(instIdx+7)}, - {program->getInstruction(instIdx), - program->getInstruction(instIdx+1), - program->getInstruction(instIdx+2), - program->getInstruction(instIdx+3), - program->getInstruction(instIdx+4), - program->getInstruction(instIdx+5), - program->getInstruction(instIdx+6)}); - } - } else if (inst->name() == "CY") { - InstPtr sdg = gateRegistry->createInstruction("Sdg", {inst->bits()[1]}); - InstPtr cx = gateRegistry->createInstruction("CNOT", inst->bits()); - InstPtr s = gateRegistry->createInstruction("S", {inst->bits()[1]}); - - program->insertInstruction(instIdx, sdg); - program->insertInstruction(instIdx+1, cx); - program->insertInstruction(instIdx+2, s); - - if (logTransCallback) { - logTransCallback({program->getInstruction(instIdx+3)}, - {program->getInstruction(instIdx), - program->getInstruction(instIdx+1), - program->getInstruction(instIdx+2)}); - } - } else if (inst->name() == "CZ") { - InstPtr h = gateRegistry->createInstruction("H", {inst->bits()[1]}); - InstPtr cx = gateRegistry->createInstruction("CNOT", inst->bits()); - InstPtr h2 = gateRegistry->createInstruction("H", {inst->bits()[1]}); - - program->insertInstruction(instIdx, h); - program->insertInstruction(instIdx+1, cx); - program->insertInstruction(instIdx+2, h2); - - if (logTransCallback) { - logTransCallback({program->getInstruction(instIdx+3)}, - {program->getInstruction(instIdx), - program->getInstruction(instIdx+1), - program->getInstruction(instIdx+2)}); - } - } else if (inst->name() == "Swap") { - InstPtr cx1 = gateRegistry->createInstruction("CNOT", inst->bits()); - InstPtr cx2 = gateRegistry->createInstruction("CNOT", {inst->bits()[1], inst->bits()[0]}); - InstPtr cx3 = gateRegistry->createInstruction("CNOT", inst->bits()); - - program->insertInstruction(instIdx, cx1); - program->insertInstruction(instIdx+1, cx2); - program->insertInstruction(instIdx+2, cx3); - - if (logTransCallback) { - logTransCallback({program->getInstruction(instIdx+3)}, - {program->getInstruction(instIdx), - program->getInstruction(instIdx+1), - program->getInstruction(instIdx+2)}); - } - } else { + _twoQubitPassVisitor->initializeInstructionVisitor(instIdx); + + inst->attachMetadata({{"composite", program}, + {"options", options}}); + + //std::cout << "instruction index is " << instIdx << ", name is " << inst->name() << std::endl; + inst->accept(_twoQubitPassVisitor); + //std::cout << "instruction index is " << instIdx << ", name is " << inst->name() << std::endl; + + if (!_twoQubitPassVisitor->instructionVisited()) + { instIdx++; - continue; + continue; } inst->disable(); diff --git a/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.hpp b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.hpp index 5bc6a45c4..21648f4ee 100644 --- a/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.hpp +++ b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.hpp @@ -24,10 +24,11 @@ namespace xacc { namespace quantum { typedef std::map, std::pair> IonTrapMSPhaseMap; +class IonTrapTwoQubitPassVisitor; class IonTrapTwoQubitPass : public IRTransformation { public: - IonTrapTwoQubitPass() {} + IonTrapTwoQubitPass(); void apply(std::shared_ptr program, const std::shared_ptr accelerator, const HeterogeneousMap &options = {}) override; diff --git a/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPassVisitor.cpp b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPassVisitor.cpp new file mode 100644 index 000000000..7ab99f647 --- /dev/null +++ b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPassVisitor.cpp @@ -0,0 +1,248 @@ +#include "xacc.hpp" +#include "xacc_service.hpp" +#include "Accelerator.hpp" +#include "IonTrapTwoQubitPassVisitor.hpp" + +namespace xacc { +namespace quantum { + IonTrapTwoQubitPassVisitor::IonTrapTwoQubitPassVisitor() + { + _instIdx = -1; // -1 means uninitialized - do we need typedef, enum for this? + _bInstructionVisited = false; + } + + bool IonTrapTwoQubitPassVisitor::instructionVisited() + { + return _bInstructionVisited; + } + + void IonTrapTwoQubitPassVisitor::initializeInstructionVisitor(std::size_t instructionIndex) + { + // initialize _instIdx to the index of the first instruction handled by this visitor + _instIdx = instructionIndex; + // set visited bool to false + _bInstructionVisited = false; + } + + std::pair IonTrapTwoQubitPassVisitor::findMSPhases(IonTrapMSPhaseMap *msPhases, Instruction *cnot) { + std::size_t leftIdx = std::min(cnot->bits()[0], cnot->bits()[1]); + std::size_t rightIdx = std::max(cnot->bits()[0], cnot->bits()[1]); + auto idxPair = std::make_pair(leftIdx, rightIdx); + + if (!msPhases->count(idxPair)) { + xacc::error("pair " + std::to_string(leftIdx) + "," + std::to_string(rightIdx) + + " missing from set of MS phases"); + } + + auto phasePair = (*msPhases)[idxPair]; + double controlMSPhase = (leftIdx == cnot->bits()[0])? phasePair.first : phasePair.second; + double targetMSPhase = (rightIdx == cnot->bits()[1])? phasePair.second : phasePair.first; return std::make_pair(controlMSPhase, targetMSPhase); + + return std::make_pair(controlMSPhase, targetMSPhase); + } + + void IonTrapTwoQubitPassVisitor::visit(CNOT& cnot) { + std::shared_ptr program = _paramsMap.get>("composite"); + std::size_t instIdx = _instIdx; + + // TODO!!! check with others if this name check is acceptaable + // this CNOT visitor is beingtriggerred by an XX instruction, so weed out for now + if (program->getInstruction(instIdx)->name() != std::string("CNOT")) + { + return; + } + + _bInstructionVisited = true; + + xacc::HeterogeneousMap options = _paramsMap.get("options"); + + IonTrapMSPhaseMap *msPhases = options.get("ms-phases"); + IonTrapLogTransformCallback logTransCallback = nullptr; + + auto [controlMSPhase, targetMSPhase] = findMSPhases(msPhases, &cnot); + + if (options.keyExists("log-trans-cb")) { + logTransCallback = options.get("log-trans-cb"); + } + + auto _gateRegistry = xacc::getService("quantum"); + InstPtr ry1 = _gateRegistry->createInstruction("Ry", {cnot.bits()[0]}, {-M_PI/2.0}); + InstPtr xx = _gateRegistry->createInstruction("XX", cnot.bits(), {M_PI/4.0}); + + InstPtr ry2 = _gateRegistry->createInstruction("Ry", {cnot.bits()[0]}, {M_PI/2.0}); + InstPtr rz = _gateRegistry->createInstruction("Rz", {cnot.bits()[0]}, {M_PI/2.0}); + InstPtr rx = _gateRegistry->createInstruction("Rx", {cnot.bits()[1]}, {M_PI/2.0}); + + program->insertInstruction(instIdx++, ry1); + // TODO: Note that this is kind of incorrect: really, the combination of these Rz gates + // and an MS gate is actually an XX gate (see https://doi.org/10.1088/1367-2630/18/2/023048) + // but we are surrounding an XX instruction with Rz instructions. But this will + // work for now + if (controlMSPhase) { + InstPtr msRz1 = _gateRegistry->createInstruction("Rz", {cnot.bits()[0]}, {controlMSPhase}); + program->insertInstruction(instIdx++, msRz1); + } + if (targetMSPhase) { + InstPtr msRz2 = _gateRegistry->createInstruction("Rz", {cnot.bits()[1]}, {targetMSPhase}); + program->insertInstruction(instIdx++, msRz2); + } + program->insertInstruction(instIdx++, xx); + if (controlMSPhase) { + InstPtr msRz3 = _gateRegistry->createInstruction("Rz", {cnot.bits()[0]}, {-controlMSPhase}); + program->insertInstruction(instIdx++, msRz3); + } + if (targetMSPhase) { + InstPtr msRz4 = _gateRegistry->createInstruction("Rz", {cnot.bits()[1]}, {-targetMSPhase}); + program->insertInstruction(instIdx++, msRz4); + } + program->insertInstruction(instIdx++, ry2); + program->insertInstruction(instIdx++, rz); + program->insertInstruction(instIdx++, rx); + + // update the logTransCallback if necessary + if (logTransCallback) { + std::vector newInsts; + for (std::size_t j = _instIdx; j < instIdx; j++) { + newInsts.push_back(program->getInstruction(j)); + } + + // log the current instruction (_instIdx-1), along with newInst tansforms {newInsts} + logTransCallback({program->getInstruction(_instIdx-1)}, {newInsts}); + } + } + + void IonTrapTwoQubitPassVisitor::visit(CH& ch) { + _bInstructionVisited = true; + + auto _gateRegistry = xacc::getService("quantum"); + + std::shared_ptr program = _paramsMap.get>("composite"); + std::size_t instIdx = _instIdx; + xacc::HeterogeneousMap options = _paramsMap.get("options"); + + IonTrapMSPhaseMap *msPhases = options.get("ms-phases"); + IonTrapLogTransformCallback logTransCallback = nullptr; + + if (options.keyExists("log-trans-cb")) { + logTransCallback = options.get("log-trans-cb"); + } + + InstPtr s = _gateRegistry->createInstruction("S", {ch.bits()[1]}); + InstPtr h = _gateRegistry->createInstruction("H", {ch.bits()[1]}); + InstPtr t = _gateRegistry->createInstruction("T", {ch.bits()[1]}); + InstPtr cx = _gateRegistry->createInstruction("CNOT", ch.bits()); + InstPtr tdg = _gateRegistry->createInstruction("Tdg", {ch.bits()[1]}); + InstPtr h2 = _gateRegistry->createInstruction("H", {ch.bits()[1]}); + InstPtr sdg = _gateRegistry->createInstruction("Sdg", {ch.bits()[1]}); + + program->insertInstruction(instIdx, s); + program->insertInstruction(instIdx+1, h); + program->insertInstruction(instIdx+2, t); + program->insertInstruction(instIdx+3, cx); + program->insertInstruction(instIdx+4, tdg); + program->insertInstruction(instIdx+5, h2); + program->insertInstruction(instIdx+6, sdg); + + if (logTransCallback) { + logTransCallback({program->getInstruction(instIdx+7)}, + {program->getInstruction(instIdx), + program->getInstruction(instIdx+1), + program->getInstruction(instIdx+2), + program->getInstruction(instIdx+3), + program->getInstruction(instIdx+4), + program->getInstruction(instIdx+5), + program->getInstruction(instIdx+6)}); + } + } + + void IonTrapTwoQubitPassVisitor::visit(CY& cy) { + _bInstructionVisited = true; + auto _gateRegistry = xacc::getService("quantum"); + std::shared_ptr program = _paramsMap.get>("composite"); + std::size_t instIdx = _instIdx; + xacc::HeterogeneousMap options = _paramsMap.get("options"); + + IonTrapMSPhaseMap *msPhases = options.get("ms-phases"); + IonTrapLogTransformCallback logTransCallback = nullptr; + + if (options.keyExists("log-trans-cb")) { + logTransCallback = options.get("log-trans-cb"); + } + + InstPtr sdg = _gateRegistry->createInstruction("Sdg", {cy.bits()[1]}); + InstPtr cx = _gateRegistry->createInstruction("CNOT", cy.bits()); + InstPtr s = _gateRegistry->createInstruction("S", {cy.bits()[1]}); + + program->insertInstruction(instIdx, sdg); + program->insertInstruction(instIdx+1, cx); + program->insertInstruction(instIdx+2, s); + + if (logTransCallback) { + logTransCallback({program->getInstruction(instIdx+3)}, + {program->getInstruction(instIdx), + program->getInstruction(instIdx+1), + program->getInstruction(instIdx+2)}); + } + } + + void IonTrapTwoQubitPassVisitor::visit(CZ& cz) { + _bInstructionVisited = true; + // get params from hmap + auto _gateRegistry = xacc::getService("quantum"); + std::shared_ptr program = _paramsMap.get>("composite"); + std::size_t instIdx = _instIdx; + xacc::HeterogeneousMap options = _paramsMap.get("options"); + + IonTrapMSPhaseMap *msPhases = options.get("ms-phases"); + IonTrapLogTransformCallback logTransCallback = nullptr; + + if (options.keyExists("log-trans-cb")) { + logTransCallback = options.get("log-trans-cb"); + } + + InstPtr h = _gateRegistry->createInstruction("H", {cz.bits()[1]}); + InstPtr cx = _gateRegistry->createInstruction("CNOT", cz.bits()); + InstPtr h2 = _gateRegistry->createInstruction("H", {cz.bits()[1]}); + + program->insertInstruction(instIdx, h); + program->insertInstruction(instIdx+1, cx); + program->insertInstruction(instIdx+2, h2); + + if (logTransCallback) { + logTransCallback({program->getInstruction(instIdx+3)}, + {program->getInstruction(instIdx), + program->getInstruction(instIdx+1), + program->getInstruction(instIdx+2)}); + } + } + void IonTrapTwoQubitPassVisitor::visit(Swap& swap) { + _bInstructionVisited = true; + auto _gateRegistry = xacc::getService("quantum"); + std::shared_ptr program = _paramsMap.get>("composite"); + std::size_t instIdx = _instIdx; + xacc::HeterogeneousMap options = _paramsMap.get("options"); + + IonTrapMSPhaseMap *msPhases = options.get("ms-phases"); + IonTrapLogTransformCallback logTransCallback = nullptr; + + if (options.keyExists("log-trans-cb")) { + logTransCallback = options.get("log-trans-cb"); + } + + InstPtr cx1 = _gateRegistry->createInstruction("CNOT", swap.bits()); + InstPtr cx2 = _gateRegistry->createInstruction("CNOT", {swap.bits()[1], swap.bits()[0]}); + InstPtr cx3 = _gateRegistry->createInstruction("CNOT", swap.bits()); + + program->insertInstruction(instIdx, cx1); + program->insertInstruction(instIdx+1, cx2); + program->insertInstruction(instIdx+2, cx3); + + if (logTransCallback) { + logTransCallback({program->getInstruction(instIdx+3)}, + {program->getInstruction(instIdx), + program->getInstruction(instIdx+1), + program->getInstruction(instIdx+2)}); + } + } +} // quantum +} // xacc diff --git a/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPassVisitor.hpp b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPassVisitor.hpp new file mode 100644 index 000000000..42f23915e --- /dev/null +++ b/quantum/plugins/iontrap/transformations/IonTrapTwoQubitPassVisitor.hpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2022 UT-Battelle, LLC. + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Public License v1.0 + * and Eclipse Distribution License v1.0 which accompanies this + * distribution. The Eclipse Public License is available at + * http://www.eclipse.org/legal/epl-v10.html and the Eclipse Distribution + * License is available at https://eclipse.org/org/documents/edl-v10.php + * + * Contributors: + * + * + *******************************************************************************/ +#ifndef QUANTUM_ACCELERATORS_IONTRAPTWOQUBITPASSVISITOR_HPP_ +#define QUANTUM_ACCELERATORS_IONTRAPTWOQUBITPASSVISITOR_HPP_ + +#include "Accelerator.hpp" +#include "IRTransformation.hpp" +#include "IonTrapPassesCommon.hpp" +#include "AllGateVisitor.hpp" +#include "IonTrapTwoQubitPass.hpp" + +using namespace xacc; + +namespace xacc { +namespace quantum { + +class IonTrapTwoQubitPassVisitor : public xacc::quantum::AllGateVisitor { +public: + + IonTrapTwoQubitPassVisitor(); + void visit(CNOT& cnot) override; + void visit(CH& ch) override; + void visit(CY& cy) override; + void visit(CZ& cz) override; + void visit(Swap& Swap) override; + + void initializeInstructionVisitor(std::size_t instructionIndex); + + bool instructionVisited(); + + HeterogeneousMap _paramsMap; + +private: + std::pair findMSPhases(IonTrapMSPhaseMap *, Instruction *); + + std::size_t _instIdx; + bool _bInstructionVisited; + +}; + +} // namespace quantum +} // namespace xacc + +#endif