From 7f5d90d18e190c603c27f44cfce43a2d0374c5d0 Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Tue, 16 Jul 2024 09:14:06 +0100 Subject: [PATCH] feat: Add lazify-measure pass (#482) Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com> Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com> --- .github/change-filters.yml | 1 + tket2-hseries/src/extension.rs | 2 +- tket2-hseries/src/extension/quantum_lazy.rs | 161 +++++++++++ tket2-hseries/src/lazify_measure.rs | 284 ++++++++++++++++++++ tket2-hseries/src/lib.rs | 2 + 5 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 tket2-hseries/src/extension/quantum_lazy.rs create mode 100644 tket2-hseries/src/lazify_measure.rs diff --git a/.github/change-filters.yml b/.github/change-filters.yml index a48bb004..f5127387 100644 --- a/.github/change-filters.yml +++ b/.github/change-filters.yml @@ -8,6 +8,7 @@ rust-core: &rust-core rust: - *rust-core + - "tket2-hseries/**" - "badger-optimiser/**" - "compile-rewriter/**" diff --git a/tket2-hseries/src/extension.rs b/tket2-hseries/src/extension.rs index 7ec3fbdc..a7905746 100644 --- a/tket2-hseries/src/extension.rs +++ b/tket2-hseries/src/extension.rs @@ -1,3 +1,3 @@ //! This module defines the Hugr extensions used by tket2-hseries. - pub mod futures; +pub mod quantum_lazy; diff --git a/tket2-hseries/src/extension/quantum_lazy.rs b/tket2-hseries/src/extension/quantum_lazy.rs new file mode 100644 index 00000000..99f930e4 --- /dev/null +++ b/tket2-hseries/src/extension/quantum_lazy.rs @@ -0,0 +1,161 @@ +//! This module defines the Hugr extension used to represent Lazy Quantum +//! Operations. +//! +//! Lazyness is represented by returning `tket2.futures.Future` classical +//! values. Qubits are never lazy. +use hugr::{ + builder::{BuildError, Dataflow}, + extension::{ + prelude::{BOOL_T, QB_T}, + simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, + ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, PRELUDE, + }, + ops::{CustomOp, OpType}, + types::FunctionType, + Extension, Wire, +}; + +use lazy_static::lazy_static; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + +use crate::extension::futures; + +use super::futures::future_type; + +/// The "tket2.quantum.lazy" extension id. +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("tket2.quantum.lazy"); + +lazy_static! { + /// The "tket2.quantum.lazy" extension. + pub static ref EXTENSION: Extension = { + let mut ext = Extension::new(EXTENSION_ID); + LazyQuantumOp::load_all_ops(&mut ext).unwrap(); + ext + }; + + /// Extension registry including the "tket2.quantum.lazy" extension and + /// dependencies. + pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + futures::EXTENSION.to_owned(), + PRELUDE.to_owned(), + EXTENSION.to_owned() + ]).unwrap(); +} + +#[derive( + Clone, + Copy, + Debug, + serde::Serialize, + serde::Deserialize, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + EnumIter, + IntoStaticStr, + EnumString, +)] +#[allow(missing_docs)] +#[non_exhaustive] +pub enum LazyQuantumOp { + Measure, +} + +impl MakeOpDef for LazyQuantumOp { + fn signature(&self) -> SignatureFunc { + match self { + Self::Measure => FunctionType::new(QB_T, vec![QB_T, future_type(BOOL_T)]).into(), + } + } + + fn from_def(op_def: &OpDef) -> Result { + try_from_name(op_def.name(), &EXTENSION_ID) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID + } +} + +impl MakeRegisteredOp for LazyQuantumOp { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + ®ISTRY + } +} + +impl TryFrom<&OpType> for LazyQuantumOp { + type Error = (); + fn try_from(value: &OpType) -> Result { + let Some(custom_op) = value.as_custom_op() else { + Err(())? + }; + match custom_op { + CustomOp::Extension(ext) => Self::from_extension_op(ext).ok(), + CustomOp::Opaque(opaque) => try_from_name(opaque.name(), &EXTENSION_ID).ok(), + } + .ok_or(()) + } +} + +/// An extension trait for [Dataflow] providing methods to add +/// "tket2.quantum.lazy" operations. +pub trait LazyQuantumOpBuilder: Dataflow { + /// Add a "tket2.quantum.lazy.Measure" op. + fn add_lazy_measure(&mut self, qb: Wire) -> Result<[Wire; 2], BuildError> { + Ok(self + .add_dataflow_op(LazyQuantumOp::Measure, [qb])? + .outputs_arr()) + } +} + +impl LazyQuantumOpBuilder for D {} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use cool_asserts::assert_matches; + use futures::FutureOpBuilder as _; + use hugr::{ + builder::{DataflowHugr, FunctionBuilder}, + ops::NamedOp, + }; + use strum::IntoEnumIterator as _; + + use super::*; + + fn get_opdef(op: impl NamedOp) -> Option<&'static Arc> { + EXTENSION.get_op(&op.name()) + } + + #[test] + fn create_extension() { + assert_eq!(EXTENSION.name(), &EXTENSION_ID); + + for o in LazyQuantumOp::iter() { + assert_eq!(LazyQuantumOp::from_def(get_opdef(o).unwrap()), Ok(o)); + } + } + + #[test] + fn circuit() { + let hugr = { + let mut func_builder = + FunctionBuilder::new("circuit", FunctionType::new(QB_T, vec![QB_T, BOOL_T])) + .unwrap(); + let [qb] = func_builder.input_wires_arr(); + let [qb, lazy_b] = func_builder.add_lazy_measure(qb).unwrap(); + let [b] = func_builder.add_read(lazy_b, BOOL_T).unwrap(); + func_builder + .finish_hugr_with_outputs([qb, b], ®ISTRY) + .unwrap() + }; + assert_matches!(hugr.validate(®ISTRY), Ok(_)); + } +} diff --git a/tket2-hseries/src/lazify_measure.rs b/tket2-hseries/src/lazify_measure.rs new file mode 100644 index 00000000..7f0d93dc --- /dev/null +++ b/tket2-hseries/src/lazify_measure.rs @@ -0,0 +1,284 @@ +//! Provides `LazifyMeasurePass` which replaces [Tket2Op::Measure] nodes with +//! [LazyQuantumOp::Measure] nodes. +//! +//! [Tket2Op::Measure]: tket2::Tk2Op::Measure +//! [LazyQuantumOp::Measure]: crate::extension::quantum_lazy::LazyQuantumOp::Measure +use std::collections::{HashMap, HashSet}; + +use hugr::{ + algorithms::validation::ValidationLevel, + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ + prelude::{BOOL_T, QB_T}, + ExtensionRegistry, + }, + hugr::{hugrmut::HugrMut, views::SiblingSubgraph, Rewrite}, + types::FunctionType, + Hugr, HugrView, IncomingPort, Node, OutgoingPort, SimpleReplacement, +}; +use tket2::Tk2Op; + +use lazy_static::lazy_static; + +use crate::extension::{ + futures::FutureOpBuilder, + quantum_lazy::{self, LazyQuantumOpBuilder}, +}; + +/// A `Hugr -> Hugr` pass that replaces [tket2::Tk2Op::Measure] nodes with +/// [quantum_lazy::LazyQuantumOp::Measure] nodes. To construct a `LazifyMeasurePass` use +/// [Default::default]. +#[derive(Default)] +pub struct LazifyMeasurePass(ValidationLevel); + +type Error = Box; + +impl LazifyMeasurePass { + /// Run `LazifyMeasurePass` on the given [HugrMut]. `registry` is used for + /// validation, if enabled. + pub fn run(&self, hugr: &mut impl HugrMut, registry: &ExtensionRegistry) -> Result<(), Error> { + self.0 + .run_validated_pass(hugr, registry, |hugr, _validation_level| { + // TODO: if _validation_level is not None, verify no non-local edges + let mut state = + State::new( + hugr.nodes() + .filter_map(|n| match hugr.get_optype(n).try_into() { + Ok(Tk2Op::Measure) => Some(WorkItem::ReplaceMeasure(n)), + _ => None, + }), + ); + while state.work_one(hugr)? {} + Ok(()) + }) + } + + /// Returns a new `LazifyMeasurePass` with the given [ValidationLevel]. + pub fn with_validation_level(mut self, level: ValidationLevel) -> Self { + self.0 = level; + self + } +} + +enum WorkItem { + ReplaceMeasure(Node), +} + +struct State { + worklist: Vec, +} + +impl State { + fn new(items: impl IntoIterator) -> Self { + let worklist = items.into_iter().collect(); + Self { worklist } + } + + fn work_one(&mut self, hugr: &mut impl HugrMut) -> Result { + let Some(item) = self.worklist.pop() else { + return Ok(false); + }; + self.worklist.extend(item.work(hugr)?); + Ok(true) + } +} + +lazy_static! { + static ref MEASURE_READ_HUGR: Hugr = { + let mut builder = DFGBuilder::new(FunctionType::new(QB_T, vec![QB_T, BOOL_T])).unwrap(); + let [qb] = builder.input_wires_arr(); + let [qb, lazy_r] = builder.add_lazy_measure(qb).unwrap(); + let [r] = builder.add_read(lazy_r, BOOL_T).unwrap(); + builder + .finish_hugr_with_outputs([qb, r], &quantum_lazy::REGISTRY) + .unwrap() + }; +} + +fn measure_replacement(num_dups: usize) -> Hugr { + let mut out_types = vec![QB_T]; + out_types.extend((0..num_dups).map(|_| BOOL_T)); + let num_out_types = out_types.len(); + let mut builder = DFGBuilder::new(FunctionType::new(QB_T, out_types)).unwrap(); + let [qb] = builder.input_wires_arr(); + let [qb, mut future_r] = builder.add_lazy_measure(qb).unwrap(); + let mut future_rs = vec![]; + if num_dups > 0 { + for _ in 0..num_dups - 1 { + let [r1, r2] = builder.add_dup(future_r, BOOL_T).unwrap(); + future_rs.push(r1); + future_r = r2; + } + future_rs.push(future_r) + } else { + builder.add_free(future_r, BOOL_T).unwrap(); + } + let mut rs = vec![qb]; + rs.extend( + future_rs + .into_iter() + .map(|r| builder.add_read(r, BOOL_T).unwrap()[0]), + ); + assert_eq!(num_out_types, rs.len()); + assert_eq!(num_out_types, num_dups + 1); + builder + .finish_hugr_with_outputs(rs, &quantum_lazy::REGISTRY) + .unwrap() +} + +fn simple_replace_measure( + hugr: &impl HugrView, + node: Node, +) -> (HashSet<(Node, IncomingPort)>, SimpleReplacement) { + assert!( + hugr.get_optype(node).try_into() == Ok(Tk2Op::Measure), + "{:?}", + hugr.get_optype(node) + ); + let g = SiblingSubgraph::try_from_nodes([node], hugr).unwrap(); + let num_uses_of_bool = hugr.linked_inputs(node, OutgoingPort::from(1)).count(); + let replacement_hugr = measure_replacement(num_uses_of_bool); + let [i, o] = replacement_hugr.get_io(replacement_hugr.root()).unwrap(); + + // A map from (target ports of edges from the Input node of `replacement`) to (target ports of + // edges from nodes not in `removal` to nodes in `removal`). + let nu_inp = replacement_hugr + .all_linked_inputs(i) + .map(|(n, p)| ((n, p), (node, p))) + .collect(); + + // qubit is linear, there must be exactly one + let (target_node, target_port) = hugr + .single_linked_input(node, OutgoingPort::from(0)) + .unwrap(); + // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to + // (input ports of the Output node of `replacement`). + let mut nu_out: HashMap<_, _> = [((target_node, target_port), IncomingPort::from(0))] + .into_iter() + .collect(); + nu_out.extend( + hugr.linked_inputs(node, OutgoingPort::from(1)) + .enumerate() + .map(|(i, target)| (target, IncomingPort::from(i + 1))), + ); + assert_eq!(nu_out.len(), 1 + num_uses_of_bool); + assert_eq!(nu_out.len(), replacement_hugr.in_value_types(o).count()); + + let nu_out_set = nu_out.keys().copied().collect(); + ( + nu_out_set, + SimpleReplacement::new(g, replacement_hugr, nu_inp, nu_out), + ) +} + +impl WorkItem { + fn work(self, hugr: &mut impl HugrMut) -> Result, Error> { + match self { + Self::ReplaceMeasure(node) => { + // for now we read immediately, but when we don't the first + // results are the linked inputs we must return + let (_, replace) = simple_replace_measure(hugr, node); + replace.apply(hugr)?; + Ok(std::iter::empty()) + } + } + } +} + +#[cfg(test)] +mod test { + use cool_asserts::assert_matches; + + use hugr::{ + extension::{ExtensionRegistry, EMPTY_REG, PRELUDE}, + std_extensions::arithmetic::float_types, + }; + use tket2::extension::TKET2_EXTENSION; + + use crate::extension::{ + futures::{self, FutureOp}, + quantum_lazy::LazyQuantumOp, + }; + + use super::*; + + lazy_static! { + pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + quantum_lazy::EXTENSION.to_owned(), + futures::EXTENSION.to_owned(), + TKET2_EXTENSION.to_owned(), + PRELUDE.to_owned(), + float_types::EXTENSION.clone(), + ]) + .unwrap(); + } + #[test] + fn simple() { + let mut hugr = { + let mut builder = DFGBuilder::new(FunctionType::new(QB_T, vec![QB_T, BOOL_T])).unwrap(); + let [qb] = builder.input_wires_arr(); + let outs = builder + .add_dataflow_op(Tk2Op::Measure, [qb]) + .unwrap() + .outputs(); + builder.finish_hugr_with_outputs(outs, ®ISTRY).unwrap() + }; + assert!(hugr.validate_no_extensions(®ISTRY).is_ok()); + LazifyMeasurePass::default() + .run(&mut hugr, &EMPTY_REG) + .unwrap(); + assert!(hugr.validate_no_extensions(®ISTRY).is_ok()); + let mut num_read = 0; + let mut num_lazy_measure = 0; + for n in hugr.nodes() { + let ot = hugr.get_optype(n); + if let Ok(FutureOp::Read) = ot.try_into() { + num_read += 1; + } else if let Ok(LazyQuantumOp::Measure) = ot.try_into() { + num_lazy_measure += 1; + } else { + assert_matches!(Tk2Op::try_from(ot), Err(_)) + } + } + + assert_eq!(1, num_read); + assert_eq!(1, num_lazy_measure); + } + + #[test] + fn multiple_uses() { + let mut builder = + DFGBuilder::new(FunctionType::new(QB_T, vec![QB_T, BOOL_T, BOOL_T])).unwrap(); + let [qb] = builder.input_wires_arr(); + let [qb, bool] = builder + .add_dataflow_op(Tk2Op::Measure, [qb]) + .unwrap() + .outputs_arr(); + let mut hugr = builder + .finish_hugr_with_outputs([qb, bool, bool], ®ISTRY) + .unwrap(); + + assert!(hugr.validate_no_extensions(®ISTRY).is_ok()); + LazifyMeasurePass::default() + .run(&mut hugr, &EMPTY_REG) + .unwrap(); + assert!(hugr.validate_no_extensions(®ISTRY).is_ok()); + } + + #[test] + fn no_uses() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(QB_T)).unwrap(); + let [qb] = builder.input_wires_arr(); + let [qb, _] = builder + .add_dataflow_op(Tk2Op::Measure, [qb]) + .unwrap() + .outputs_arr(); + let mut hugr = builder.finish_hugr_with_outputs([qb], ®ISTRY).unwrap(); + + assert!(hugr.validate_no_extensions(®ISTRY).is_ok()); + LazifyMeasurePass::default() + .run(&mut hugr, &EMPTY_REG) + .unwrap(); + assert!(hugr.validate_no_extensions(®ISTRY).is_ok()); + } +} diff --git a/tket2-hseries/src/lib.rs b/tket2-hseries/src/lib.rs index d640c673..cea19dce 100644 --- a/tket2-hseries/src/lib.rs +++ b/tket2-hseries/src/lib.rs @@ -8,6 +8,8 @@ pub mod cli; pub mod extension; +pub mod lazify_measure; + /// Modify a [Hugr] into a form that is acceptable for ingress into an H-series. /// /// Returns an error if this cannot be done.