From 9fc92daa0bcb20f0aaae21e30eceb0d2f22c3424 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 7 Sep 2023 19:02:24 +0100 Subject: [PATCH 01/12] feat: Hugrs as circuits, no more lifetimes --- Cargo.toml | 4 +-- benches/benchmarks/hash.rs | 4 +-- compile-matcher/src/main.rs | 2 +- src/circuit.rs | 64 +++++++++++++++---------------------- src/circuit/command.rs | 17 +++------- src/circuit/hash.rs | 10 +++--- src/json.rs | 4 +-- src/json/encoder.rs | 2 +- src/ops.rs | 8 ++--- src/portmatching/matcher.rs | 62 +++++++++++++---------------------- src/portmatching/pattern.rs | 20 +++--------- src/portmatching/pyo3.rs | 4 +-- src/rewrite.rs | 2 +- src/rewrite/ecc_rewriter.rs | 2 +- 14 files changed, 77 insertions(+), 128 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dec2ddee..e9d0ce53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,7 +69,7 @@ members = ["pyrs", "compile-matcher"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "5a97a635" } -portgraph = { version = "0.8", features = ["serde"] } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "e23323d" } +portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/benches/benchmarks/hash.rs b/benches/benchmarks/hash.rs index ad62304c..a9f28c87 100644 --- a/benches/benchmarks/hash.rs +++ b/benches/benchmarks/hash.rs @@ -1,7 +1,7 @@ use criterion::{black_box, criterion_group, AxisScale, BenchmarkId, Criterion, PlotConfiguration}; -use hugr::hugr::views::SiblingGraph; +use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::HugrView; -use tket2::circuit::{CircuitHash, HierarchyView}; +use tket2::circuit::CircuitHash; use super::generators::make_cnot_layers; diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs index 5c0f59dc..2078a94f 100644 --- a/compile-matcher/src/main.rs +++ b/compile-matcher/src/main.rs @@ -2,7 +2,7 @@ use std::fs; use std::path::Path; use clap::Parser; -use hugr::hugr::views::{HierarchyView, SiblingGraph}; +use hugr::hugr::views::SiblingGraph; use hugr::ops::handle::DfgID; use hugr::HugrView; use itertools::Itertools; diff --git a/src/circuit.rs b/src/circuit.rs index a36fd2da..e7ffc4fc 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -18,12 +18,10 @@ use hugr::hugr::{CircuitUnit, NodeType}; use hugr::ops::OpTrait; use hugr::HugrView; -pub use hugr::hugr::views::HierarchyView; pub use hugr::ops::OpType; use hugr::types::TypeBound; pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; -use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers}; /// An object behaving like a quantum circuit. // @@ -32,9 +30,11 @@ use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers}; // - Vertical slice iterator // - Gate count map // - Depth -pub trait Circuit<'circ>: HugrView { +pub trait Circuit: HugrView { /// An iterator over the commands in the circuit. - type Commands: Iterator; + type Commands<'a>: Iterator + where + Self: 'a; /// An iterator over the commands applied to an unit. type UnitCommands: Iterator; @@ -67,10 +67,10 @@ pub trait Circuit<'circ>: HugrView { /// Returns all the commands in the circuit, in some topological order. /// /// Ignores the Input and Output nodes. - fn commands(&'circ self) -> Self::Commands; + fn commands(&self) -> Self::Commands<'_>; /// Returns all the commands applied to the given unit, in order. - fn unit_commands(&'circ self) -> Self::UnitCommands; + fn unit_commands(&self) -> Self::UnitCommands; /// Returns the [`NodeType`] of a command. fn command_nodetype(&self, command: &Command) -> &NodeType { @@ -86,12 +86,11 @@ pub trait Circuit<'circ>: HugrView { fn num_gates(&self) -> usize; } -impl<'circ, T> Circuit<'circ> for T +impl Circuit for T where - T: 'circ + HierarchyView<'circ>, - for<'a> &'a T: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, + T: HugrView, { - type Commands = CommandIterator<'circ, T>; + type Commands<'a> = CommandIterator<'a, T> where Self: 'a; type UnitCommands = std::iter::Empty; #[inline] @@ -129,12 +128,12 @@ where } } - fn commands(&'circ self) -> Self::Commands { + fn commands(&self) -> Self::Commands<'_> { // Traverse the circuit in topological order. CommandIterator::new(self) } - fn unit_commands(&'circ self) -> Self::UnitCommands { + fn unit_commands(&self) -> Self::UnitCommands { // TODO Can we associate linear i/o with the corresponding unit without // doing the full toposort? unimplemented!() @@ -158,35 +157,24 @@ where #[cfg(test)] mod tests { - use std::sync::OnceLock; - - use hugr::{ - hugr::views::{DescendantsGraph, HierarchyView}, - ops::handle::DfgID, - Hugr, HugrView, - }; + use hugr::Hugr; use crate::{circuit::Circuit, json::load_tk1_json_str}; - static CIRC: OnceLock = OnceLock::new(); - - fn test_circuit() -> DescendantsGraph<'static, DfgID> { - let hugr = CIRC.get_or_init(|| { - load_tk1_json_str( - r#"{ - "phase": "0", - "bits": [], - "qubits": [["q", [0]], ["q", [1]]], - "commands": [ - {"args": [["q", [0]]], "op": {"type": "H"}}, - {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}} - ], - "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] - }"#, - ) - .unwrap() - }); - DescendantsGraph::new(hugr, hugr.root()) + fn test_circuit() -> Hugr { + load_tk1_json_str( + r#"{ + "phase": "0", + "bits": [], + "qubits": [["q", [0]], ["q", [1]]], + "commands": [ + {"args": [["q", [0]]], "op": {"type": "H"}}, + {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}} + ], + "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] + }"#, + ) + .unwrap() } #[test] diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 747a47cb..570c8bb2 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -5,9 +5,7 @@ use std::collections::HashMap; use std::iter::FusedIterator; -use hugr::hugr::views::HierarchyView; use hugr::ops::{OpTag, OpTrait}; -use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers}; use super::Circuit; @@ -59,8 +57,7 @@ pub struct CommandIterator<'circ, Circ> { impl<'circ, Circ> CommandIterator<'circ, Circ> where - Circ: HierarchyView<'circ>, - for<'a> &'a Circ: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, + Circ: Circuit, { /// Create a new iterator over the commands of a circuit. pub(super) fn new(circ: &'circ Circ) -> Self { @@ -77,7 +74,7 @@ where }) .collect(); - let nodes = petgraph::algo::toposort(circ, None).unwrap(); + let nodes = petgraph::algo::toposort(&circ.as_petgraph(), None).unwrap(); Self { circ, nodes, @@ -157,8 +154,7 @@ where impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> where - Circ: HierarchyView<'circ>, - for<'a> &'a Circ: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, + Circ: Circuit, { type Item = Command; @@ -182,12 +178,7 @@ where } } -impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> -where - Circ: HierarchyView<'circ>, - for<'a> &'a Circ: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, -{ -} +impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {} #[cfg(test)] mod test { diff --git a/src/circuit/hash.rs b/src/circuit/hash.rs index b62ddd46..e9a0109d 100644 --- a/src/circuit/hash.rs +++ b/src/circuit/hash.rs @@ -4,7 +4,6 @@ use core::panic; use std::hash::{Hash, Hasher}; use fxhash::{FxHashMap, FxHasher64}; -use hugr::hugr::views::HierarchyView; use hugr::ops::{LeafOp, OpName, OpTag, OpTrait, OpType}; use hugr::types::TypeBound; use hugr::{HugrView, Node, Port}; @@ -29,14 +28,15 @@ pub trait CircuitHash<'circ>: HugrView { impl<'circ, T> CircuitHash<'circ> for T where - T: HugrView + HierarchyView<'circ>, - for<'a> &'a T: - pg::GraphBase + pg::IntoNeighborsDirected + pg::IntoNodeIdentifiers, + T: HugrView, { fn circuit_hash(&'circ self) -> u64 { let mut hash_state = HashState::default(); - for node in pg::Topo::new(self).iter(self).filter(|&n| n != self.root()) { + for node in pg::Topo::new(&self.as_petgraph()) + .iter(&self.as_petgraph()) + .filter(|&n| n != self.root()) + { let hash = hash_node(self, node, &mut hash_state); hash_state.set_node(self, node, hash); } diff --git a/src/json.rs b/src/json.rs index 9afc2189..ff9ebc85 100644 --- a/src/json.rs +++ b/src/json.rs @@ -38,7 +38,7 @@ pub trait TKETDecode: Sized { /// Convert the serialized circuit to a [`Hugr`]. fn decode(self) -> Result; /// Convert a [`Hugr`] to a new serialized circuit. - fn encode<'circ>(circuit: &'circ impl Circuit<'circ>) -> Result; + fn encode(circuit: &impl Circuit) -> Result; } impl TKETDecode for SerialCircuit { @@ -60,7 +60,7 @@ impl TKETDecode for SerialCircuit { Ok(decoder.finish()) } - fn encode<'circ>(circ: &'circ impl Circuit<'circ>) -> Result { + fn encode(circ: &impl Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); for com in circ.commands() { let optype = circ.command_optype(&com); diff --git a/src/json/encoder.rs b/src/json/encoder.rs index ad81b653..e087d147 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -40,7 +40,7 @@ pub(super) struct JsonEncoder { impl JsonEncoder { /// Create a new [`JsonEncoder`] from a [`Circuit`]. - pub fn new<'circ>(circ: &impl Circuit<'circ>) -> Self { + pub fn new(circ: &impl Circuit) -> Self { let name = circ.name().map(str::to_string); // Compute the linear qubit and bit registers. Each one have independent diff --git a/src/ops.rs b/src/ops.rs index 8783e698..f29894c5 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -289,12 +289,8 @@ pub(crate) mod test { use std::sync::Arc; - use hugr::{ - extension::OpDef, - hugr::views::{HierarchyView, SiblingGraph}, - ops::handle::DfgID, - Hugr, HugrView, - }; + use hugr::hugr::views::HierarchyView; + use hugr::{extension::OpDef, hugr::views::SiblingGraph, ops::handle::DfgID, Hugr, HugrView}; use rstest::{fixture, rstest}; use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit}; diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs index 4dcd8915..fe3655ff 100644 --- a/src/portmatching/matcher.rs +++ b/src/portmatching/matcher.rs @@ -73,7 +73,7 @@ pub struct PatternMatch<'a, C> { pub(super) root: Node, } -impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> { +impl<'a, C: Circuit + Clone> PatternMatch<'a, C> { /// The matcher's pattern ID of the match. pub fn pattern_id(&self) -> PatternID { self.pattern @@ -187,7 +187,7 @@ impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> { } } -impl<'a, C: Circuit<'a>> Debug for PatternMatch<'a, C> { +impl<'a, C: Circuit> Debug for PatternMatch<'a, C> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PatternMatch") .field("root", &self.root) @@ -237,10 +237,7 @@ impl PatternMatcher { } /// Find all convex pattern matches in a circuit. - pub fn find_matches<'a, C: Circuit<'a> + Clone>( - &self, - circuit: &'a C, - ) -> Vec> { + pub fn find_matches<'a, C: Circuit + Clone>(&self, circuit: &'a C) -> Vec> { let mut checker = ConvexChecker::new(circuit); circuit .commands() @@ -249,7 +246,7 @@ impl PatternMatcher { } /// Find all convex pattern matches in a circuit rooted at a given node. - fn find_rooted_matches<'a, C: Circuit<'a> + Clone>( + fn find_rooted_matches<'a, C: Circuit + Clone>( &self, circ: &'a C, root: Node, @@ -385,8 +382,8 @@ fn compatible_offsets((_, pout): &(Port, Port), (pin, _): &(Port, Port)) -> bool } /// Check if an edge `e` is valid in a portgraph `g` without weights. -pub(crate) fn validate_unweighted_edge<'circ>( - circ: &impl Circuit<'circ>, +pub(crate) fn validate_unweighted_edge( + circ: &impl Circuit, ) -> impl for<'a> Fn(Node, &'a PEdge) -> Option + '_ { move |src, &(src_port, tgt_port)| { let (next_node, _) = circ @@ -397,8 +394,8 @@ pub(crate) fn validate_unweighted_edge<'circ>( } /// Check if a node `n` is valid in a weighted portgraph `g`. -pub(crate) fn validate_weighted_node<'circ>( - circ: &impl Circuit<'circ>, +pub(crate) fn validate_weighted_node( + circ: &impl Circuit, ) -> impl for<'a> Fn(Node, &PNode) -> bool + '_ { move |v, prop| { let v_weight = MatchOp::try_from(circ.get_optype(v).clone()); @@ -425,11 +422,7 @@ fn handle_match_error(match_res: Result, root: Node) #[cfg(test)] mod tests { - use std::sync::OnceLock; - - use hugr::hugr::views::{DescendantsGraph, HierarchyView}; - use hugr::ops::handle::DfgID; - use hugr::{Hugr, HugrView}; + use hugr::Hugr; use itertools::Itertools; use crate::utils::build_simple_circuit; @@ -437,31 +430,22 @@ mod tests { use super::{CircuitPattern, PatternMatcher}; - static H_CX: OnceLock = OnceLock::new(); - static CX_CX: OnceLock = OnceLock::new(); - - fn h_cx<'a>() -> DescendantsGraph<'a, DfgID> { - let circ = H_CX.get_or_init(|| { - build_simple_circuit(2, |circ| { - circ.append(T2Op::CX, [0, 1]).unwrap(); - circ.append(T2Op::H, [0]).unwrap(); - Ok(()) - }) - .unwrap() - }); - DescendantsGraph::new(circ, circ.root()) + fn h_cx() -> Hugr { + build_simple_circuit(2, |circ| { + circ.append(T2Op::CX, [0, 1]).unwrap(); + circ.append(T2Op::H, [0]).unwrap(); + Ok(()) + }) + .unwrap() } - fn cx_xc<'a>() -> DescendantsGraph<'a, DfgID> { - let circ = CX_CX.get_or_init(|| { - build_simple_circuit(2, |circ| { - circ.append(T2Op::CX, [0, 1]).unwrap(); - circ.append(T2Op::CX, [1, 0]).unwrap(); - Ok(()) - }) - .unwrap() - }); - DescendantsGraph::new(circ, circ.root()) + fn cx_xc() -> Hugr { + build_simple_circuit(2, |circ| { + circ.append(T2Op::CX, [0, 1]).unwrap(); + circ.append(T2Op::CX, [1, 0]).unwrap(); + Ok(()) + }) + .unwrap() } #[test] diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index 36e3b56f..0d123f10 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -33,9 +33,7 @@ impl CircuitPattern { } /// Construct a pattern from a circuit. - pub fn try_from_circuit<'circ, C: Circuit<'circ>>( - circuit: &'circ C, - ) -> Result { + pub fn try_from_circuit(circuit: &C) -> Result { if circuit.num_gates() == 0 { return Err(InvalidPattern::EmptyCircuit); } @@ -79,11 +77,7 @@ impl CircuitPattern { } /// Compute the map from pattern nodes to circuit nodes in `circ`. - pub fn get_match_map<'a, C: Circuit<'a>>( - &self, - root: Node, - circ: &C, - ) -> Option> { + pub fn get_match_map(&self, root: Node, circ: &C) -> Option> { let single_matcher = SinglePatternMatcher::from_pattern(self.pattern.clone()); single_matcher .get_match_map( @@ -121,9 +115,7 @@ impl From for InvalidPattern { #[cfg(test)] mod tests { - use hugr::hugr::views::{DescendantsGraph, HierarchyView, SiblingGraph}; - use hugr::ops::handle::DfgID; - use hugr::{Hugr, HugrView}; + use hugr::Hugr; use itertools::Itertools; use crate::utils::build_simple_circuit; @@ -143,9 +135,8 @@ mod tests { #[test] fn construct_pattern() { let hugr = h_cx(); - let circ: DescendantsGraph<'_, DfgID> = DescendantsGraph::new(&hugr, hugr.root()); - let p = CircuitPattern::try_from_circuit(&circ).unwrap(); + let p = CircuitPattern::try_from_circuit(&hugr).unwrap(); let edges = p .pattern @@ -163,13 +154,12 @@ mod tests { #[test] fn disconnected_pattern() { - let hugr = build_simple_circuit(2, |circ| { + let circ = build_simple_circuit(2, |circ| { circ.append(T2Op::X, [0])?; circ.append(T2Op::T, [1])?; Ok(()) }) .unwrap(); - let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root()); assert_eq!( CircuitPattern::try_from_circuit(&circ).unwrap_err(), InvalidPattern::NotConnected diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index 41faed6c..a08e90a3 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -121,8 +121,8 @@ impl PyPatternMatch { /// /// Requires references to the circuit and pattern to resolve indices /// into these objects. - pub fn try_from_rust<'circ, C: Circuit<'circ> + Clone>( - m: PatternMatch<'circ, C>, + pub fn try_from_rust( + m: PatternMatch, circ: &C, matcher: &PatternMatcher, ) -> PyResult { diff --git a/src/rewrite.rs b/src/rewrite.rs index d4919878..07f0ee7f 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -52,5 +52,5 @@ impl CircuitRewrite { /// Generate rewrite rules for circuits. pub trait Rewriter { /// Get the rewrite rules for a circuit. - fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec; + fn get_rewrites<'a, C: Circuit + Clone>(&'a self, circ: &'a C) -> Vec; } diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 32d181ed..af4a3eeb 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -96,7 +96,7 @@ impl ECCRewriter { } impl Rewriter for ECCRewriter { - fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec { + fn get_rewrites<'a, C: Circuit + Clone>(&'a self, circ: &'a C) -> Vec { let matches = self.matcher.find_matches(circ); matches .into_iter() From ed53f1738c1427e3e2816b2079556295c98bded9 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 7 Sep 2023 22:24:40 +0100 Subject: [PATCH 02/12] Rework Circuit unit iterators, expand API --- compile-matcher/src/main.rs | 3 +- src/circuit.rs | 189 ++++++++++++++++-------------------- src/circuit/command.rs | 21 +++- src/circuit/units.rs | 109 +++++++++++++++++++++ src/json/tests.rs | 2 +- 5 files changed, 213 insertions(+), 111 deletions(-) create mode 100644 src/circuit/units.rs diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs index 2078a94f..61c4e188 100644 --- a/compile-matcher/src/main.rs +++ b/compile-matcher/src/main.rs @@ -65,9 +65,8 @@ fn main() { let patterns = all_circs .iter() .filter_map(|circ| { - let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&circ, circ.root()); // Fail silently on empty or disconnected patterns - CircuitPattern::try_from_circuit(&circ).ok() + CircuitPattern::try_from_circuit(circ).ok() }) .collect_vec(); println!("Loaded {} patterns.", patterns.len()); diff --git a/src/circuit.rs b/src/circuit.rs index e7ffc4fc..2abfcf73 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -2,159 +2,138 @@ pub mod command; mod hash; +mod units; +pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; -//#[cfg(feature = "pyo3")] -//pub mod py_circuit; - -//#[cfg(feature = "tkcxx")] -//pub mod unitarybox; - -use self::command::{Command, CommandIterator}; - -use hugr::extension::prelude::QB_T; -use hugr::hugr::{CircuitUnit, NodeType}; -use hugr::ops::OpTrait; +use hugr::hugr::NodeType; use hugr::HugrView; pub use hugr::ops::OpType; -use hugr::types::TypeBound; +use hugr::types::FunctionType; pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; +use self::units::{UnitType, Units}; + /// An object behaving like a quantum circuit. // // TODO: More methods: // - other_{in,out}puts (for non-linear i/o + const inputs)? // - Vertical slice iterator -// - Gate count map // - Depth pub trait Circuit: HugrView { - /// An iterator over the commands in the circuit. - type Commands<'a>: Iterator - where - Self: 'a; - - /// An iterator over the commands applied to an unit. - type UnitCommands: Iterator; - /// Return the name of the circuit - fn name(&self) -> Option<&str>; - - /// Get the linear inputs of the circuit and their types. - fn units(&self) -> Vec<(CircuitUnit, Type)>; - - /// Returns the units corresponding to qubits inputs to the circuit. #[inline] - fn qubits(&self) -> Vec { - self.units() - .iter() - .filter(|(_, typ)| typ == &QB_T) - .map(|(unit, _)| *unit) - .collect() + fn name(&self) -> Option<&str> { + let meta = self.get_metadata(self.root()).as_object()?; + meta.get("name")?.as_str() } - /// Returns the input node to the circuit. - fn input(&self) -> Node; - - /// Returns the output node to the circuit. - fn output(&self) -> Node; - - /// Given a linear port in a node, returns the corresponding port on the other side of the node (if any). - fn follow_linear_port(&self, node: Node, port: Port) -> Option; - - /// Returns all the commands in the circuit, in some topological order. + /// Returns the function type of the circuit. /// - /// Ignores the Input and Output nodes. - fn commands(&self) -> Self::Commands<'_>; - - /// Returns all the commands applied to the given unit, in order. - fn unit_commands(&self) -> Self::UnitCommands; + /// Equivalent to [`HugrView::get_function_type`]. + #[inline] + fn circuit_signature(&self) -> &FunctionType { + self.get_function_type() + .expect("Circuit has no function type") + } - /// Returns the [`NodeType`] of a command. - fn command_nodetype(&self, command: &Command) -> &NodeType { - self.get_nodetype(command.node()) + /// Returns the input node to the circuit. + #[inline] + fn input(&self) -> Node { + return self + .children(self.root()) + .next() + .expect("Circuit has no input node"); } - /// Returns the [`OpType`] of a command. - fn command_optype(&self, command: &Command) -> &OpType { - self.get_optype(command.node()) + /// Returns the output node to the circuit. + #[inline] + fn output(&self) -> Node { + return self + .children(self.root()) + .nth(1) + .expect("Circuit has no output node"); } /// The number of gates in the circuit. - fn num_gates(&self) -> usize; -} - -impl Circuit for T -where - T: HugrView, -{ - type Commands<'a> = CommandIterator<'a, T> where Self: 'a; - type UnitCommands = std::iter::Empty; - #[inline] - fn name(&self) -> Option<&str> { - let meta = self.get_metadata(self.root()).as_object()?; - meta.get("name")?.as_str() + fn num_gates(&self) -> usize { + self.children(self.root()).count() - 2 } + /// Count the number of qubits in the circuit. #[inline] - fn units(&self) -> Vec<(CircuitUnit, Type)> { - let root = self.root(); - let optype = self.get_optype(root); - optype - .signature() - .input_types() - .iter() - .filter(|&typ| !TypeBound::Copyable.contains(typ.least_upper_bound())) - .enumerate() - .map(|(i, typ)| (i.into(), typ.clone())) - .collect() + fn qubit_count(&self) -> usize + where + Self: Sized, + { + self.qubits().count() } - fn follow_linear_port(&self, node: Node, port: Port) -> Option { - let optype = self.get_optype(node); - if !optype.port_kind(port)?.is_linear() { - return None; - } - // TODO: We assume the linear data uses the same port offsets on both sides of the node. - // In the future we may want to have a more general mechanism to handle this. - let other_port = Port::new(port.direction().reverse(), port.index()); - if optype.port_kind(other_port) == optype.port_kind(port) { - Some(other_port) - } else { - None - } + /// Get the input units of the circuit and their types. + #[inline] + fn units(&self) -> Units<'_> + where + Self: Sized, + { + Units::new(self, UnitType::All) } - fn commands(&self) -> Self::Commands<'_> { - // Traverse the circuit in topological order. - CommandIterator::new(self) + /// Get the linear input units of the circuit and their types. + #[inline] + fn linear_units(&self) -> Units<'_> + where + Self: Sized, + { + Units::new(self, UnitType::Linear) } - fn unit_commands(&self) -> Self::UnitCommands { - // TODO Can we associate linear i/o with the corresponding unit without - // doing the full toposort? - unimplemented!() + /// Get the linear input units of the circuit and their types. + #[inline] + fn nonlinear_units(&self) -> Units<'_> + where + Self: Sized, + { + Units::new(self, UnitType::NonLinear) } + /// Returns the units corresponding to qubits inputs to the circuit. #[inline] - fn input(&self) -> Node { - return self.children(self.root()).next().unwrap(); + fn qubits(&self) -> Units<'_> + where + Self: Sized, + { + Units::new(self, UnitType::Qubits) } + /// Returns all the commands in the circuit, in some topological order. + /// + /// Ignores the Input and Output nodes. #[inline] - fn output(&self) -> Node { - return self.children(self.root()).nth(1).unwrap(); + fn commands(&self) -> CommandIterator<'_, Self> + where + Self: Sized, + { + // Traverse the circuit in topological order. + CommandIterator::new(self) } - #[inline] - fn num_gates(&self) -> usize { - self.children(self.root()).count() - 2 + /// Returns the [`NodeType`] of a command. + fn command_nodetype(&self, command: &Command) -> &NodeType { + self.get_nodetype(command.node()) + } + + /// Returns the [`OpType`] of a command. + fn command_optype(&self, command: &Command) -> &OpType { + self.get_optype(command.node()) } } +impl Circuit for T where T: HugrView {} + #[cfg(test)] mod tests { use hugr::Hugr; diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 570c8bb2..e41e21fa 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -67,9 +67,9 @@ where .node_outputs(circ.input()) .map(|port| Wire::new(circ.input(), port)); let wire_unit = input_node_wires - .zip(circ.units().iter()) + .zip(circ.linear_units()) .filter_map(|(wire, (unit, _))| match unit { - CircuitUnit::Linear(i) => Some((wire, *i)), + CircuitUnit::Linear(i) => Some((wire, i)), _ => None, }) .collect(); @@ -111,7 +111,7 @@ where // Get the unit corresponding to a wire, or return a wire Unit. match self.wire_unit.remove(&wire) { Some(unit) => { - if let Some(new_port) = self.circ.follow_linear_port(node, port) { + if let Some(new_port) = self.follow_linear_port(node, port) { self.wire_unit.insert(Wire::new(node, new_port), unit); } Some(CircuitUnit::Linear(unit)) @@ -150,6 +150,21 @@ where outputs, }) } + + fn follow_linear_port(&self, node: Node, port: Port) -> Option { + let optype = self.circ.get_optype(node); + if !optype.port_kind(port)?.is_linear() { + return None; + } + // TODO: We assume the linear data uses the same port offsets on both sides of the node. + // In the future we may want to have a more general mechanism to handle this. + let other_port = Port::new(port.direction().reverse(), port.index()); + if optype.port_kind(other_port) == optype.port_kind(port) { + Some(other_port) + } else { + None + } + } } impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> diff --git a/src/circuit/units.rs b/src/circuit/units.rs new file mode 100644 index 00000000..bc5d3c30 --- /dev/null +++ b/src/circuit/units.rs @@ -0,0 +1,109 @@ +//! Iterators over the units of a circuit. + +use std::iter::FusedIterator; + +use hugr::extension::prelude; +use hugr::hugr::CircuitUnit; +use hugr::types::{Type, TypeBound, TypeRow}; +use hugr::{Node, Port, Wire}; + +use super::Circuit; + +/// An iterator over the units of a circuit. +pub struct Units<'a> { + /// Whether to only + output_mode: UnitType, + /// The inputs to the circuit + inputs: Option<&'a TypeRow>, + /// Input node of the circuit + input_node: Node, + /// The current index in the inputs + current: usize, + /// The amount of linear units yielded. + linear_count: usize, +} + +impl<'a> Units<'a> { + /// Create a new iterator over the units of a circuit. + pub(super) fn new(circuit: &'a impl Circuit, output_mode: UnitType) -> Self { + Self { + output_mode, + inputs: circuit.get_function_type().map(|ft| &ft.input), + input_node: circuit.input(), + current: 0, + linear_count: 0, + } + } + + /// Construct an output value to yield. + fn make_value(&self, typ: &Type, input_port: Port) -> (CircuitUnit, Type) { + match type_is_linear(typ) { + true => (CircuitUnit::Linear(self.linear_count - 1), typ.clone()), + false => ( + CircuitUnit::Wire(Wire::new(self.input_node, input_port)), + typ.clone(), + ), + } + } +} + +impl<'a> Iterator for Units<'a> { + type Item = (CircuitUnit, Type); + + fn next(&mut self) -> Option { + let inputs = self.inputs?; + loop { + let typ = inputs.get(self.current)?; + let input_port = Port::new_outgoing(self.current); + self.current += 1; + if type_is_linear(typ) { + self.linear_count += 1; + } + if self.output_mode.accept(typ) { + return Some(self.make_value(typ, input_port)); + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self + .inputs + .map(|inputs| inputs.len() - self.current) + .unwrap_or(0); + match self.output_mode { + UnitType::All => (len, Some(len)), + _ => (0, Some(len)), + } + } +} + +impl<'a> FusedIterator for Units<'a> {} + +/// What kind of units to iterate over. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum UnitType { + /// All units. + All, + /// Only the linear units. + Linear, + /// Only the qubit units. + Qubits, + /// Only the non-linear units. + NonLinear, +} + +impl UnitType { + /// Check if a [`Type`] should be yielded. + pub fn accept(self, typ: &Type) -> bool { + match self { + UnitType::All => true, + UnitType::Linear => type_is_linear(typ), + UnitType::Qubits => *typ == prelude::QB_T, + UnitType::NonLinear => !type_is_linear(typ), + } + } +} + +fn type_is_linear(typ: &Type) -> bool { + !TypeBound::Copyable.contains(typ.least_upper_bound()) +} diff --git a/src/json/tests.rs b/src/json/tests.rs index 7261c302..76f7cb9c 100644 --- a/src/json/tests.rs +++ b/src/json/tests.rs @@ -61,7 +61,7 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num let hugr: Hugr = ser.clone().decode().unwrap(); let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root()); - assert_eq!(circ.qubits().len(), num_qubits); + assert_eq!(circ.qubit_count(), num_qubits); let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); compare_serial_circs(&ser, &reser); From 1f7c0b7d8f9c46bb87df9a8ea974acca80f0b531 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 7 Sep 2023 22:32:11 +0100 Subject: [PATCH 03/12] Test the iterators --- src/circuit.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 2abfcf73..55fcb83b 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -144,11 +144,12 @@ mod tests { load_tk1_json_str( r#"{ "phase": "0", - "bits": [], + "bits": [["c", [0]]], "qubits": [["q", [0]], ["q", [1]]], "commands": [ {"args": [["q", [0]]], "op": {"type": "H"}}, - {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}} + {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}, + {"args": [["q", [1]]], "op": {"type": "X"}} ], "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] }"#, @@ -157,8 +158,21 @@ mod tests { } #[test] - fn test_num_gates() { + fn test_circuit_properties() { let circ = test_circuit(); - assert_eq!(circ.num_gates(), 2); + + assert_eq!(circ.name(), None); + assert_eq!(circ.circuit_signature().input.len(), 3); + assert_eq!(circ.circuit_signature().output.len(), 3); + assert_eq!(circ.qubit_count(), 2); + assert_eq!(circ.num_gates(), 3); + + assert_eq!(circ.units().count(), 3); + assert_eq!(circ.nonlinear_units().count(), 0); + assert_eq!(circ.linear_units().count(), 3); + assert_eq!(circ.qubits().count(), 2); + + assert!(circ.linear_units().all(|(unit, _)| unit.is_linear())); + assert!(circ.nonlinear_units().all(|(unit, _)| unit.is_wire())); } } From 2cb8c9e49a19b3c5c03773234b03b2c14d3d3dc5 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 7 Sep 2023 22:57:42 +0100 Subject: [PATCH 04/12] fix doc --- src/circuit.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/circuit.rs b/src/circuit.rs index 55fcb83b..d5c46ec3 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -91,7 +91,7 @@ pub trait Circuit: HugrView { Units::new(self, UnitType::Linear) } - /// Get the linear input units of the circuit and their types. + /// Get the non-linear input units of the circuit and their types. #[inline] fn nonlinear_units(&self) -> Units<'_> where From 1e2d4ba6e730492254270f404c45489611b34f16 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 11 Sep 2023 16:39:52 +0100 Subject: [PATCH 05/12] Use the Topo iterator for CommandIter --- src/circuit/command.rs | 43 ++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/src/circuit/command.rs b/src/circuit/command.rs index e41e21fa..ef1f50fc 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -2,10 +2,11 @@ //! //! A [`Command`] is an operation applied to an specific wires, possibly identified by their index in the circuit's input vector. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::iter::FusedIterator; use hugr::ops::{OpTag, OpTrait}; +use petgraph::visit as pv; use super::Circuit; @@ -42,17 +43,29 @@ impl Command { } } +type NodeWalker = pv::Topo>; + /// An iterator over the commands of a circuit. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CommandIterator<'circ, Circ> { - /// The circuit + /// The circuit. circ: &'circ Circ, - /// Toposorted nodes - nodes: Vec, - /// Current element in `nodes` - current: usize, - /// Last wires for each linear `CircuitUnit` + /// Toposorted nodes. + nodes: NodeWalker, + /// Last wires for each linear `CircuitUnit`. wire_unit: HashMap, + /// Remaining commands, not counting I/O nodes. + remaining: usize, +} + +impl<'circ, Circ: std::fmt::Debug> std::fmt::Debug for CommandIterator<'circ, Circ> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommandIterator") + .field("circ", &self.circ) + .field("wire_unit", &self.wire_unit) + .field("remaining", &self.remaining) + .finish() + } } impl<'circ, Circ> CommandIterator<'circ, Circ> @@ -74,12 +87,12 @@ where }) .collect(); - let nodes = petgraph::algo::toposort(&circ.as_petgraph(), None).unwrap(); + let nodes = pv::Topo::new(&circ.as_petgraph()); Self { circ, nodes, - current: 0, wire_unit, + remaining: circ.node_count() - 2, } } @@ -173,15 +186,13 @@ where { type Item = Command; + #[inline] fn next(&mut self) -> Option { loop { - if self.current == self.nodes.len() { - return None; - } - let node = self.nodes[self.current]; + let node = self.nodes.next(&self.circ.as_petgraph())?; let com = self.process_node(node); - self.current += 1; if com.is_some() { + self.remaining -= 1; return com; } } @@ -189,7 +200,7 @@ where #[inline] fn size_hint(&self) -> (usize, Option) { - (0, Some(self.nodes.len() - self.current)) + (self.remaining, Some(self.remaining)) } } From fc7acbc6684a8fbcad8f0130451ef3f5f9c851fa Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 11 Sep 2023 18:14:51 +0100 Subject: [PATCH 06/12] Allow arbitrary nodes in `Units` - Also, optionally return the ports for each unit --- src/circuit.rs | 16 +++---- src/circuit/command.rs | 1 + src/circuit/units.rs | 104 +++++++++++++++++++++++++++++++---------- 3 files changed, 89 insertions(+), 32 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index d5c46ec3..40cc9518 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -75,38 +75,38 @@ pub trait Circuit: HugrView { /// Get the input units of the circuit and their types. #[inline] - fn units(&self) -> Units<'_> + fn units(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::All) + Units::new_input(self, UnitType::All) } /// Get the linear input units of the circuit and their types. #[inline] - fn linear_units(&self) -> Units<'_> + fn linear_units(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::Linear) + Units::new_input(self, UnitType::Linear) } /// Get the non-linear input units of the circuit and their types. #[inline] - fn nonlinear_units(&self) -> Units<'_> + fn nonlinear_units(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::NonLinear) + Units::new_input(self, UnitType::NonLinear) } /// Returns the units corresponding to qubits inputs to the circuit. #[inline] - fn qubits(&self) -> Units<'_> + fn qubits(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::Qubits) + Units::new_input(self, UnitType::Qubits) } /// Returns all the commands in the circuit, in some topological order. diff --git a/src/circuit/command.rs b/src/circuit/command.rs index ef1f50fc..64be0f4c 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -43,6 +43,7 @@ impl Command { } } +/// A non-borrowing topological walker over the nodes of a circuit. type NodeWalker = pv::Topo>; /// An iterator over the commands of a circuit. diff --git a/src/circuit/units.rs b/src/circuit/units.rs index bc5d3c30..aef21dd2 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -4,56 +4,89 @@ use std::iter::FusedIterator; use hugr::extension::prelude; use hugr::hugr::CircuitUnit; +use hugr::ops::OpTrait; use hugr::types::{Type, TypeBound, TypeRow}; use hugr::{Node, Port, Wire}; use super::Circuit; -/// An iterator over the units of a circuit. -pub struct Units<'a> { - /// Whether to only +/// An iterator over the units at the output of a [Node]. +#[derive(Clone, Debug)] +pub struct Units { + /// Whether to only. output_mode: UnitType, - /// The inputs to the circuit - inputs: Option<&'a TypeRow>, - /// Input node of the circuit - input_node: Node, - /// The current index in the inputs - current: usize, + /// The types of the node outputs. + // + // TODO: We could avoid cloning the TypeRow if `OpType::signature` returned + // a reference. + node_output_types: TypeRow, + /// The node of the circuit. + node: Node, + /// The current index in the inputs. + pub(self) current: usize, /// The amount of linear units yielded. linear_count: usize, } -impl<'a> Units<'a> { - /// Create a new iterator over the units of a circuit. - pub(super) fn new(circuit: &'a impl Circuit, output_mode: UnitType) -> Self { +impl Units { + /// Create a new iterator over the units of a node. + // + // FIXME: Currently this ignores any incoming linear unit labels, and just + // assigns new ids sequentially. + #[inline] + #[allow(unused)] + pub(super) fn new(circuit: &impl Circuit, node: Node, output_mode: UnitType) -> Self { Self { output_mode, - inputs: circuit.get_function_type().map(|ft| &ft.input), - input_node: circuit.input(), + node_output_types: circuit.get_optype(node).signature().output, + node, current: 0, linear_count: 0, } } + /// Create a new iterator over the input units of a circuit. + /// + /// This iterator will yield all units originating from the circuit's input + /// node. + #[inline] + pub(super) fn new_input(circuit: &impl Circuit, output_mode: UnitType) -> Self { + Self { + output_mode, + node_output_types: circuit + .get_function_type() + .map_or_else(Default::default, |ft| ft.input.clone()), + node: circuit.input(), + current: 0, + linear_count: 0, + } + } + + /// Add the corresponding ports to the iterator output. + #[inline] + pub fn with_ports(self) -> UnitPorts { + UnitPorts { units: self } + } + /// Construct an output value to yield. + #[inline] fn make_value(&self, typ: &Type, input_port: Port) -> (CircuitUnit, Type) { match type_is_linear(typ) { true => (CircuitUnit::Linear(self.linear_count - 1), typ.clone()), false => ( - CircuitUnit::Wire(Wire::new(self.input_node, input_port)), + CircuitUnit::Wire(Wire::new(self.node, input_port)), typ.clone(), ), } } } -impl<'a> Iterator for Units<'a> { +impl Iterator for Units { type Item = (CircuitUnit, Type); fn next(&mut self) -> Option { - let inputs = self.inputs?; loop { - let typ = inputs.get(self.current)?; + let typ = self.node_output_types.get(self.current)?; let input_port = Port::new_outgoing(self.current); self.current += 1; if type_is_linear(typ) { @@ -66,10 +99,7 @@ impl<'a> Iterator for Units<'a> { } fn size_hint(&self) -> (usize, Option) { - let len = self - .inputs - .map(|inputs| inputs.len() - self.current) - .unwrap_or(0); + let len = self.node_output_types.len() - self.current; match self.output_mode { UnitType::All => (len, Some(len)), _ => (0, Some(len)), @@ -77,12 +107,38 @@ impl<'a> Iterator for Units<'a> { } } -impl<'a> FusedIterator for Units<'a> {} +impl FusedIterator for Units {} + +/// An iterator over the units of a circuit, including their [`Port`]s. +/// +/// A simple wrapper around [`Units`]. +#[repr(transparent)] +pub struct UnitPorts { + /// The internal Units iterator. + units: Units, +} + +impl Iterator for UnitPorts { + type Item = (CircuitUnit, Port, Type); + + fn next(&mut self) -> Option { + let port = Port::new_outgoing(self.units.current); + let (unit, typ) = self.units.next()?; + Some((unit, port, typ)) + } + + fn size_hint(&self) -> (usize, Option) { + self.units.size_hint() + } +} + +impl FusedIterator for UnitPorts {} /// What kind of units to iterate over. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub(super) enum UnitType { /// All units. + #[default] All, /// Only the linear units. Linear, From a3e0ba68634a87cbc245270eced3b5ac52b1a675 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 11 Sep 2023 19:12:46 +0100 Subject: [PATCH 07/12] Add a linear unit assigner to Units --- src/circuit.rs | 4 +- src/circuit/command.rs | 8 +-- src/circuit/units.rs | 113 ++++++++++++++++++++++------------------- src/json/encoder.rs | 2 +- 4 files changed, 69 insertions(+), 58 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 40cc9518..bc8b1d6b 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -172,7 +172,7 @@ mod tests { assert_eq!(circ.linear_units().count(), 3); assert_eq!(circ.qubits().count(), 2); - assert!(circ.linear_units().all(|(unit, _)| unit.is_linear())); - assert!(circ.nonlinear_units().all(|(unit, _)| unit.is_wire())); + assert!(circ.linear_units().all(|(unit, _, _)| unit.is_linear())); + assert!(circ.nonlinear_units().all(|(unit, _, _)| unit.is_wire())); } } diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 64be0f4c..28ab62e0 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -82,7 +82,7 @@ where .map(|port| Wire::new(circ.input(), port)); let wire_unit = input_node_wires .zip(circ.linear_units()) - .filter_map(|(wire, (unit, _))| match unit { + .filter_map(|(wire, (unit, _, _))| match unit { CircuitUnit::Linear(i) => Some((wire, i)), _ => None, }) @@ -165,13 +165,15 @@ where }) } + /// Returns the linear port on the node that corresponds to the same linear unit. + /// + /// We assume the linear data uses the same port offsets on both sides of the node. + /// In the future we may want to have a more general mechanism to handle this. fn follow_linear_port(&self, node: Node, port: Port) -> Option { let optype = self.circ.get_optype(node); if !optype.port_kind(port)?.is_linear() { return None; } - // TODO: We assume the linear data uses the same port offsets on both sides of the node. - // In the future we may want to have a more general mechanism to handle this. let other_port = Port::new(port.direction().reverse(), port.index()); if optype.port_kind(other_port) == optype.port_kind(port) { Some(other_port) diff --git a/src/circuit/units.rs b/src/circuit/units.rs index aef21dd2..9c028163 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -12,7 +12,7 @@ use super::Circuit; /// An iterator over the units at the output of a [Node]. #[derive(Clone, Debug)] -pub struct Units { +pub struct Units { /// Whether to only. output_mode: UnitType, /// The types of the node outputs. @@ -26,25 +26,14 @@ pub struct Units { pub(self) current: usize, /// The amount of linear units yielded. linear_count: usize, + /// A pre-set assignment of units that maps linear output ports to + /// [`CircuitUnit`] ids. + /// + /// The default type is `()`, which assigns new linear ids sequentially. + unit_assigner: LA, } -impl Units { - /// Create a new iterator over the units of a node. - // - // FIXME: Currently this ignores any incoming linear unit labels, and just - // assigns new ids sequentially. - #[inline] - #[allow(unused)] - pub(super) fn new(circuit: &impl Circuit, node: Node, output_mode: UnitType) -> Self { - Self { - output_mode, - node_output_types: circuit.get_optype(node).signature().output, - node, - current: 0, - linear_count: 0, - } - } - +impl Units<()> { /// Create a new iterator over the input units of a circuit. /// /// This iterator will yield all units originating from the circuit's input @@ -59,22 +48,49 @@ impl Units { node: circuit.input(), current: 0, linear_count: 0, + unit_assigner: (), } } +} - /// Add the corresponding ports to the iterator output. +impl Units +where + LA: LinearUnitAssigner, +{ + /// Create a new iterator over the units of a node. + // + // Note that this ignores any incoming linear unit labels, and just assigns + // new unit ids sequentially. #[inline] - pub fn with_ports(self) -> UnitPorts { - UnitPorts { units: self } + #[allow(unused)] + pub(super) fn new( + circuit: &impl Circuit, + node: Node, + output_mode: UnitType, + unit_assigner: LA, + ) -> Self { + Self { + output_mode, + node_output_types: circuit.get_optype(node).signature().output, + node, + current: 0, + linear_count: 0, + unit_assigner, + } } /// Construct an output value to yield. #[inline] - fn make_value(&self, typ: &Type, input_port: Port) -> (CircuitUnit, Type) { + fn make_value(&self, typ: &Type, port: Port) -> (CircuitUnit, Port, Type) { match type_is_linear(typ) { - true => (CircuitUnit::Linear(self.linear_count - 1), typ.clone()), + true => ( + self.unit_assigner.assign(port, self.linear_count - 1), + port, + typ.clone(), + ), false => ( - CircuitUnit::Wire(Wire::new(self.node, input_port)), + CircuitUnit::Wire(Wire::new(self.node, port)), + port, typ.clone(), ), } @@ -82,18 +98,18 @@ impl Units { } impl Iterator for Units { - type Item = (CircuitUnit, Type); + type Item = (CircuitUnit, Port, Type); fn next(&mut self) -> Option { loop { let typ = self.node_output_types.get(self.current)?; - let input_port = Port::new_outgoing(self.current); + let port = Port::new_outgoing(self.current); self.current += 1; if type_is_linear(typ) { self.linear_count += 1; } if self.output_mode.accept(typ) { - return Some(self.make_value(typ, input_port)); + return Some(self.make_value(typ, port)); } } } @@ -109,31 +125,6 @@ impl Iterator for Units { impl FusedIterator for Units {} -/// An iterator over the units of a circuit, including their [`Port`]s. -/// -/// A simple wrapper around [`Units`]. -#[repr(transparent)] -pub struct UnitPorts { - /// The internal Units iterator. - units: Units, -} - -impl Iterator for UnitPorts { - type Item = (CircuitUnit, Port, Type); - - fn next(&mut self) -> Option { - let port = Port::new_outgoing(self.units.current); - let (unit, typ) = self.units.next()?; - Some((unit, port, typ)) - } - - fn size_hint(&self) -> (usize, Option) { - self.units.size_hint() - } -} - -impl FusedIterator for UnitPorts {} - /// What kind of units to iterate over. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub(super) enum UnitType { @@ -160,6 +151,24 @@ impl UnitType { } } +/// A map for assigning linear unit ids to ports. +pub trait LinearUnitAssigner { + /// Assign a linear unit id to an output port. + fn assign(&self, port: Port, unit: usize) -> CircuitUnit; +} + +impl LinearUnitAssigner for () { + fn assign(&self, _port: Port, unit: usize) -> CircuitUnit { + CircuitUnit::Linear(unit) + } +} + +impl<'a> LinearUnitAssigner for &'a Vec { + fn assign(&self, _port: Port, unit: usize) -> CircuitUnit { + CircuitUnit::Linear(self[unit]) + } +} + fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } diff --git a/src/json/encoder.rs b/src/json/encoder.rs index e087d147..ed11985e 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -49,7 +49,7 @@ impl JsonEncoder { // TODO Throw an error on non-recognized unit types, or just ignore? let mut bit_units = HashMap::new(); let mut qubit_units = HashMap::new(); - for (unit, ty) in circ.units() { + for (unit, _, ty) in circ.units() { if ty == QB_T { let index = vec![qubit_units.len() as i64]; let reg = circuit_json::Register("q".to_string(), index); From c874b84ca86bd47cff6d1eb64557b6972d8b426d Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Mon, 11 Sep 2023 21:05:30 +0100 Subject: [PATCH 08/12] Expand the Command API, and use it --- src/circuit.rs | 19 +--- src/circuit/command.rs | 203 +++++++++++++++++++++++++++--------- src/circuit/units.rs | 112 +++++++++++--------- src/json.rs | 4 +- src/json/encoder.rs | 22 ++-- src/portmatching/pattern.rs | 4 +- 6 files changed, 236 insertions(+), 128 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index bc8b1d6b..169410f0 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -7,7 +7,6 @@ mod units; pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; -use hugr::hugr::NodeType; use hugr::HugrView; pub use hugr::ops::OpType; @@ -79,7 +78,7 @@ pub trait Circuit: HugrView { where Self: Sized, { - Units::new_input(self, UnitType::All) + Units::new_circ_input(self, UnitType::All) } /// Get the linear input units of the circuit and their types. @@ -88,7 +87,7 @@ pub trait Circuit: HugrView { where Self: Sized, { - Units::new_input(self, UnitType::Linear) + Units::new_circ_input(self, UnitType::Linear) } /// Get the non-linear input units of the circuit and their types. @@ -97,7 +96,7 @@ pub trait Circuit: HugrView { where Self: Sized, { - Units::new_input(self, UnitType::NonLinear) + Units::new_circ_input(self, UnitType::NonLinear) } /// Returns the units corresponding to qubits inputs to the circuit. @@ -106,7 +105,7 @@ pub trait Circuit: HugrView { where Self: Sized, { - Units::new_input(self, UnitType::Qubits) + Units::new_circ_input(self, UnitType::Qubits) } /// Returns all the commands in the circuit, in some topological order. @@ -120,16 +119,6 @@ pub trait Circuit: HugrView { // Traverse the circuit in topological order. CommandIterator::new(self) } - - /// Returns the [`NodeType`] of a command. - fn command_nodetype(&self, command: &Command) -> &NodeType { - self.get_nodetype(command.node()) - } - - /// Returns the [`OpType`] of a command. - fn command_optype(&self, command: &Command) -> &OpType { - self.get_optype(command.node()) - } } impl Circuit for T where T: HugrView {} diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 28ab62e0..542cbdad 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -5,9 +5,11 @@ use std::collections::{HashMap, HashSet}; use std::iter::FusedIterator; +use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; use petgraph::visit as pv; +use super::units::{LinearUnit, LinearUnitAssigner, UnitType, Units}; use super::Circuit; pub use hugr::hugr::CircuitUnit; @@ -16,30 +18,109 @@ pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Direction, Node, Port, Wire}; /// An operation applied to specific wires. -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Command { +#[derive(Eq, PartialOrd, Ord, Hash)] +pub struct Command<'circ, Circ> { + /// The circuit. + circ: &'circ Circ, /// The operation node. node: Node, - /// The input units to the operation. - inputs: Vec, - /// The output units to the operation. - outputs: Vec, + /// An assignment of linear units to the node's ports. + linear_units: Vec>, } -impl Command { +impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the node corresponding to this command. pub fn node(&self) -> Node { self.node } + /// Returns the [`NodeType`] of the command. + pub fn nodetype(&self) -> &NodeType { + self.circ.get_nodetype(self.node) + } + + /// Returns the [`OpType`] of the command. + pub fn optype(&self) -> &OpType { + self.circ.get_optype(self.node) + } + /// Returns the output units of this command. - pub fn outputs(&self) -> &Vec { - &self.outputs + pub fn outputs(&self) -> Units<&'_ Self> { + Units::new( + self.circ, + self.node, + Direction::Outgoing, + UnitType::All, + self, + ) + } + + /// Returns the output wires of this command. + pub fn output_wires(&self) -> impl FusedIterator + '_ { + self.outputs() + .map(|(unit, port, _)| (unit, Wire::new(self.node, port))) } /// Returns the output units of this command. - pub fn inputs(&self) -> &Vec { - &self.inputs + pub fn inputs(&self) -> Units<&'_ Self> { + Units::new( + self.circ, + self.node, + Direction::Incoming, + UnitType::All, + self, + ) + } + + /// Returns the number of inputs of this command. + pub fn input_count(&self) -> usize { + self.optype().signature().input_count() + } + + /// Returns the number of outputs of this command. + pub fn output_count(&self) -> usize { + self.optype().signature().output_count() + } +} + +impl<'a, 'circ, Circ> LinearUnitAssigner for &'a Command<'circ, Circ> { + fn assign(&self, port: Port, _linear_count: usize) -> LinearUnit { + self.linear_units + .get(port.index()) + .copied() + .flatten() + .unwrap_or_else(|| { + panic!( + "Could not assign a linear unit to port {port:?} of node {:?}", + self.node + ) + }) + } +} + +impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Command") + .field("circ", &self.circ.name()) + .field("node", &self.node) + .field("linear_units", &self.linear_units) + .finish() + } +} + +impl<'circ, Circ> PartialEq for Command<'circ, Circ> { + fn eq(&self, other: &Self) -> bool { + self.node == other.node && self.linear_units == other.linear_units + } +} + +impl<'circ, Circ> Clone for Command<'circ, Circ> { + fn clone(&self) -> Self { + Self { + circ: self.circ, + node: self.node, + linear_units: self.linear_units.clone(), + } } } @@ -75,15 +156,12 @@ where { /// Create a new iterator over the commands of a circuit. pub(super) fn new(circ: &'circ Circ) -> Self { - // Initialize the linear units from the input's linear ports. - // TODO: Clean this up - let input_node_wires = circ - .node_outputs(circ.input()) - .map(|port| Wire::new(circ.input(), port)); - let wire_unit = input_node_wires - .zip(circ.linear_units()) - .filter_map(|(wire, (unit, _, _))| match unit { - CircuitUnit::Linear(i) => Some((wire, i)), + // Initialize the map assigning linear units to the input's linear + // ports. + let wire_unit = circ + .units() + .filter_map(|(unit, port, _)| match unit { + CircuitUnit::Linear(i) => Some((Wire::new(circ.input(), port), i)), _ => None, }) .collect(); @@ -93,13 +171,19 @@ where circ, nodes, wire_unit, - remaining: circ.node_count() - 2, + // Ignore the input and output nodes, and the root. + remaining: circ.node_count() - 3, } } - /// Process a new node, updating wires in `unit_wires` and returns the - /// command for the node if it's not an input or output. - fn process_node(&mut self, node: Node) -> Option { + /// Process a new node, updating wires in `unit_wires`. + /// + /// Returns the an option with the `linear_units` used to construct a + /// [`Command`], if the node is not an input or output. + /// + /// We don't return the command directly to avoid lifetime issues due to the + /// mutable borrow here. + fn process_node(&mut self, node: Node) -> Option>> { let optype = self.circ.get_optype(node); let sig = optype.signature(); @@ -110,7 +194,7 @@ where // Get the wire corresponding to each input unit. // TODO: Add this to HugrView? - let inputs: Vec<_> = sig + let _inputs: Vec<_> = sig .input_ports() .chain( // add the static input port @@ -158,11 +242,8 @@ where optype.other_port_index(Direction::Outgoing).unwrap(), ))) } - Some(Command { - node, - inputs, - outputs, - }) + + todo!() } /// Returns the linear port on the node that corresponds to the same linear unit. @@ -187,16 +268,19 @@ impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> where Circ: Circuit, { - type Item = Command; + type Item = Command<'circ, Circ>; #[inline] fn next(&mut self) -> Option { loop { let node = self.nodes.next(&self.circ.as_petgraph())?; - let com = self.process_node(node); - if com.is_some() { + if let Some(linear_units) = self.process_node(node) { self.remaining -= 1; - return com; + return Some(Command { + circ: self.circ, + node, + linear_units, + }); } } } @@ -214,12 +298,21 @@ mod test { use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::ops::OpName; use hugr::HugrView; + use itertools::Itertools; + use std::fmt::Debug; use crate::utils::build_simple_circuit; use crate::T2Op; use super::*; + fn assert_eq_iter(x: impl Iterator, expected: impl IntoIterator) + where + T: PartialEq + Debug, + { + assert_eq!(x.collect_vec(), expected.into_iter().collect_vec()); + } + #[test] fn iterate_commands() { let hugr = build_simple_circuit(2, |circ| { @@ -239,31 +332,37 @@ mod test { let mut commands = CommandIterator::new(&circ); let hadamard = commands.next().unwrap(); - assert_eq!( - circ.command_optype(&hadamard).name().as_str(), - t2op_name(T2Op::H) + assert_eq!(hadamard.optype().name().as_str(), t2op_name(T2Op::H)); + assert_eq_iter( + hadamard.inputs().map(|(u, _, _)| u), + [CircuitUnit::Linear(0)], + ); + assert_eq_iter( + hadamard.outputs().map(|(u, _, _)| u), + [CircuitUnit::Linear(0)], ); - assert_eq!(hadamard.inputs(), &[CircuitUnit::Linear(0)]); - assert_eq!(hadamard.outputs(), &[CircuitUnit::Linear(0)]); let cx = commands.next().unwrap(); - assert_eq!( - circ.command_optype(&cx).name().as_str(), - t2op_name(T2Op::CX) - ); - assert_eq!( - cx.inputs(), - &[CircuitUnit::Linear(0), CircuitUnit::Linear(1)] + assert_eq!(cx.optype().name().as_str(), t2op_name(T2Op::CX)); + assert_eq_iter( + cx.inputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(0), CircuitUnit::Linear(1)], ); - assert_eq!( - cx.outputs(), - &[CircuitUnit::Linear(0), CircuitUnit::Linear(1)] + assert_eq_iter( + cx.outputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(0), CircuitUnit::Linear(1)], ); let t = commands.next().unwrap(); - assert_eq!(circ.command_optype(&t).name().as_str(), t2op_name(T2Op::T)); - assert_eq!(t.inputs(), &[CircuitUnit::Linear(1)]); - assert_eq!(t.outputs(), &[CircuitUnit::Linear(1)]); + assert_eq!(t.optype().name().as_str(), t2op_name(T2Op::T)); + assert_eq_iter( + t.inputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(1)], + ); + assert_eq_iter( + t.outputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(1)], + ); assert_eq!(commands.next(), None); } diff --git a/src/circuit/units.rs b/src/circuit/units.rs index 9c028163..fc1949a9 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -6,28 +6,37 @@ use hugr::extension::prelude; use hugr::hugr::CircuitUnit; use hugr::ops::OpTrait; use hugr::types::{Type, TypeBound, TypeRow}; -use hugr::{Node, Port, Wire}; +use hugr::{Direction, Node, Port, Wire}; use super::Circuit; -/// An iterator over the units at the output of a [Node]. +/// A linear unit id, used in [`CircuitUnit::Linear`]. +// TODO: Add this to hugr? +pub type LinearUnit = usize; + +/// An iterator over the units in the input or output boundary of a [Node]. #[derive(Clone, Debug)] pub struct Units { - /// Whether to only. - output_mode: UnitType, - /// The types of the node outputs. + /// Filter over the yielded units. + /// + /// It can be set to ignore non-linear units, only yield qubits, between + /// other options. See [`UnitType`] for more information. + mode: UnitType, + /// The node of the circuit. + node: Node, + /// The direction of the boundary. + direction: Direction, + /// The types of the boundary. // // TODO: We could avoid cloning the TypeRow if `OpType::signature` returned // a reference. - node_output_types: TypeRow, - /// The node of the circuit. - node: Node, - /// The current index in the inputs. - pub(self) current: usize, + type_row: TypeRow, + /// The current index in the type row. + current: usize, /// The amount of linear units yielded. linear_count: usize, - /// A pre-set assignment of units that maps linear output ports to - /// [`CircuitUnit`] ids. + /// A pre-set assignment of units that maps linear ports to + /// [`CircuitUnit::Linear`] ids. /// /// The default type is `()`, which assigns new linear ids sequentially. unit_assigner: LA, @@ -39,13 +48,14 @@ impl Units<()> { /// This iterator will yield all units originating from the circuit's input /// node. #[inline] - pub(super) fn new_input(circuit: &impl Circuit, output_mode: UnitType) -> Self { + pub(super) fn new_circ_input(circuit: &impl Circuit, output_mode: UnitType) -> Self { Self { - output_mode, - node_output_types: circuit + mode: output_mode, + node: circuit.input(), + direction: Direction::Outgoing, + type_row: circuit .get_function_type() .map_or_else(Default::default, |ft| ft.input.clone()), - node: circuit.input(), current: 0, linear_count: 0, unit_assigner: (), @@ -62,17 +72,23 @@ where // Note that this ignores any incoming linear unit labels, and just assigns // new unit ids sequentially. #[inline] - #[allow(unused)] pub(super) fn new( circuit: &impl Circuit, node: Node, + direction: Direction, output_mode: UnitType, unit_assigner: LA, ) -> Self { + let sig = circuit.get_optype(node).signature(); + let type_row = match direction { + Direction::Outgoing => sig.output, + Direction::Incoming => sig.input, + }; Self { - output_mode, - node_output_types: circuit.get_optype(node).signature().output, + mode: output_mode, node, + direction, + type_row, current: 0, linear_count: 0, unit_assigner, @@ -80,50 +96,54 @@ where } /// Construct an output value to yield. + /// + /// Calls [`LinearUnitAssigner::assign`] to assign a linear unit id to the linear ports. + /// Non-linear ports are assigned [`CircuitUnit::Wire`]s. #[inline] fn make_value(&self, typ: &Type, port: Port) -> (CircuitUnit, Port, Type) { - match type_is_linear(typ) { - true => ( - self.unit_assigner.assign(port, self.linear_count - 1), - port, - typ.clone(), - ), - false => ( - CircuitUnit::Wire(Wire::new(self.node, port)), - port, - typ.clone(), - ), - } + let unit = if type_is_linear(typ) { + let linear_unit = self.unit_assigner.assign(port, self.linear_count - 1); + CircuitUnit::Linear(linear_unit) + } else { + match self.direction { + Direction::Outgoing => CircuitUnit::Wire(Wire::new(self.node, port)), + Direction::Incoming => CircuitUnit::Wire(Wire::new(self.node, port)), + } + }; + (unit, port, typ.clone()) } } -impl Iterator for Units { +impl Iterator for Units +where + LA: LinearUnitAssigner, +{ type Item = (CircuitUnit, Port, Type); fn next(&mut self) -> Option { loop { - let typ = self.node_output_types.get(self.current)?; - let port = Port::new_outgoing(self.current); + let typ = self.type_row.get(self.current)?; + let port = Port::new(self.direction, self.current); self.current += 1; if type_is_linear(typ) { self.linear_count += 1; } - if self.output_mode.accept(typ) { + if self.mode.accept(typ) { return Some(self.make_value(typ, port)); } } } fn size_hint(&self) -> (usize, Option) { - let len = self.node_output_types.len() - self.current; - match self.output_mode { + let len = self.type_row.len() - self.current; + match self.mode { UnitType::All => (len, Some(len)), _ => (0, Some(len)), } } } -impl FusedIterator for Units {} +impl FusedIterator for Units where LA: LinearUnitAssigner {} /// What kind of units to iterate over. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] @@ -154,18 +174,16 @@ impl UnitType { /// A map for assigning linear unit ids to ports. pub trait LinearUnitAssigner { /// Assign a linear unit id to an output port. - fn assign(&self, port: Port, unit: usize) -> CircuitUnit; + /// + /// # Parameters + /// - port: The node's port in the node. + /// - linear_count: The number of linear units yielded so far. + fn assign(&self, port: Port, linear_count: usize) -> LinearUnit; } impl LinearUnitAssigner for () { - fn assign(&self, _port: Port, unit: usize) -> CircuitUnit { - CircuitUnit::Linear(unit) - } -} - -impl<'a> LinearUnitAssigner for &'a Vec { - fn assign(&self, _port: Port, unit: usize) -> CircuitUnit { - CircuitUnit::Linear(self[unit]) + fn assign(&self, _port: Port, linear_count: usize) -> LinearUnit { + linear_count } } diff --git a/src/json.rs b/src/json.rs index ff9ebc85..effec43a 100644 --- a/src/json.rs +++ b/src/json.rs @@ -63,8 +63,8 @@ impl TKETDecode for SerialCircuit { fn encode(circ: &impl Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); for com in circ.commands() { - let optype = circ.command_optype(&com); - encoder.add_command(com, optype)?; + let optype = com.optype(); + encoder.add_command(com.clone(), optype)?; } Ok(encoder.finish()) } diff --git a/src/json/encoder.rs b/src/json/encoder.rs index ed11985e..0d0a43de 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -88,7 +88,11 @@ impl JsonEncoder { } /// Add a circuit command to the serialization. - pub fn add_command(&mut self, command: Command, optype: &OpType) -> Result<(), OpConvertError> { + pub fn add_command( + &mut self, + command: Command<'_, C>, + optype: &OpType, + ) -> Result<(), OpConvertError> { // Register any output of the command that can be used as a TKET1 parameter. if self.record_parameters(&command, optype) { // for now all ops that record parameters should be ignored (are @@ -99,8 +103,7 @@ impl JsonEncoder { let (args, params): (Vec, Vec) = command .inputs() - .iter() - .partition_map(|&u| match self.unit_to_register(u) { + .partition_map(|(u, _, _)| match self.unit_to_register(u) { Some(r) => Either::Left(r), None => match u { CircuitUnit::Wire(w) => Either::Right(w), @@ -145,17 +148,16 @@ impl JsonEncoder { /// Record any output of the command that can be used as a TKET1 parameter. /// Returns whether parameters were recorded. /// Associates the output wires with the parameter expression. - fn record_parameters(&mut self, command: &Command, optype: &OpType) -> bool { + fn record_parameters(&mut self, command: &Command<'_, C>, optype: &OpType) -> bool { // Only consider commands where all inputs are parameters. let inputs = command .inputs() - .iter() - .filter_map(|unit| match unit { - CircuitUnit::Wire(wire) => self.parameters.get(wire), + .filter_map(|(unit, _, _)| match unit { + CircuitUnit::Wire(wire) => self.parameters.get(&wire), CircuitUnit::Linear(_) => None, }) .collect_vec(); - if inputs.len() != command.inputs().len() { + if inputs.len() != command.input_count() { return false; } @@ -194,9 +196,9 @@ impl JsonEncoder { } }; - for unit in command.outputs() { + for (unit, _, _) in command.outputs() { if let CircuitUnit::Wire(wire) = unit { - self.parameters.insert(*wire, param.clone()); + self.parameters.insert(wire, param.clone()); } } true diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index 0d123f10..58bfa085 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -39,9 +39,9 @@ impl CircuitPattern { } let mut pattern = Pattern::new(); for cmd in circuit.commands() { - let op = circuit.command_optype(&cmd).clone(); + let op = cmd.optype().clone(); pattern.require(cmd.node(), op.try_into().unwrap()); - for out_offset in 0..cmd.outputs().len() { + for out_offset in 0..cmd.output_count() { let out_offset = Port::new_outgoing(out_offset); for (next_node, in_offset) in circuit.linked_ports(cmd.node(), out_offset) { if circuit.get_optype(next_node).tag() != hugr::ops::OpTag::Output { From c399e7374c853f24a84d1fefcdb0520774e1bb77 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 12 Sep 2023 00:20:53 +0100 Subject: [PATCH 09/12] Correctly track all the units --- src/circuit/command.rs | 179 ++++++++++++++++++++--------------------- src/circuit/units.rs | 126 +++++++++++++++++++---------- src/json/encoder.rs | 4 + 3 files changed, 176 insertions(+), 133 deletions(-) diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 542cbdad..f1e1d2e0 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -9,7 +9,7 @@ use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; use petgraph::visit as pv; -use super::units::{LinearUnit, LinearUnitAssigner, UnitType, Units}; +use super::units::{LinearUnit, UnitLabeller, UnitType, Units}; use super::Circuit; pub use hugr::hugr::CircuitUnit; @@ -25,7 +25,10 @@ pub struct Command<'circ, Circ> { /// The operation node. node: Node, /// An assignment of linear units to the node's ports. - linear_units: Vec>, + // + // We'll need something more complex if `follow_linear_port` stops being a + // direct map from input to output. + linear_units: Vec, } impl<'circ, Circ: Circuit> Command<'circ, Circ> { @@ -74,7 +77,8 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the number of inputs of this command. pub fn input_count(&self) -> usize { - self.optype().signature().input_count() + let optype = self.optype(); + optype.signature().input_count() + optype.static_input().is_some() as usize } /// Returns the number of outputs of this command. @@ -83,25 +87,33 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { } } -impl<'a, 'circ, Circ> LinearUnitAssigner for &'a Command<'circ, Circ> { - fn assign(&self, port: Port, _linear_count: usize) -> LinearUnit { - self.linear_units - .get(port.index()) - .copied() - .flatten() - .unwrap_or_else(|| { - panic!( - "Could not assign a linear unit to port {port:?} of node {:?}", - self.node - ) - }) +impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { + #[inline] + fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit { + *self.linear_units.get(port.index()).unwrap_or_else(|| { + panic!( + "Could not assign a linear unit to port {port:?} of node {:?}", + self.node + ) + }) + } + + #[inline] + fn assign_wire(&self, node: Node, port: Port) -> Option { + match port.direction() { + Direction::Incoming => { + let (from, from_port) = self.circ.linked_ports(node, port).next()?; + Some(Wire::new(from, from_port)) + } + Direction::Outgoing => Some(Wire::new(node, port)), + } } } impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Command") - .field("circ", &self.circ.name()) + .field("circuit name", &self.circ.name()) .field("node", &self.node) .field("linear_units", &self.linear_units) .finish() @@ -134,22 +146,12 @@ pub struct CommandIterator<'circ, Circ> { circ: &'circ Circ, /// Toposorted nodes. nodes: NodeWalker, - /// Last wires for each linear `CircuitUnit`. + /// Last wire for each [`LinearUnit`] in the circuit. wire_unit: HashMap, /// Remaining commands, not counting I/O nodes. remaining: usize, } -impl<'circ, Circ: std::fmt::Debug> std::fmt::Debug for CommandIterator<'circ, Circ> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CommandIterator") - .field("circ", &self.circ) - .field("wire_unit", &self.wire_unit) - .field("remaining", &self.remaining) - .finish() - } -} - impl<'circ, Circ> CommandIterator<'circ, Circ> where Circ: Circuit, @@ -183,73 +185,56 @@ where /// /// We don't return the command directly to avoid lifetime issues due to the /// mutable borrow here. - fn process_node(&mut self, node: Node) -> Option>> { - let optype = self.circ.get_optype(node); - let sig = optype.signature(); - + fn process_node(&mut self, node: Node) -> Option> { // The root node is ignored. if node == self.circ.root() { return None; } - - // Get the wire corresponding to each input unit. - // TODO: Add this to HugrView? - let _inputs: Vec<_> = sig - .input_ports() - .chain( - // add the static input port - optype - .static_input() - // TODO query optype for this port once it is available in hugr. - .map(|_| Port::new_incoming(sig.input.len())), - ) - .filter_map(|port| { - let (from, from_port) = self.circ.linked_ports(node, port).next()?; - let wire = Wire::new(from, from_port); - // Get the unit corresponding to a wire, or return a wire Unit. - match self.wire_unit.remove(&wire) { - Some(unit) => { - if let Some(new_port) = self.follow_linear_port(node, port) { - self.wire_unit.insert(Wire::new(node, new_port), unit); - } - Some(CircuitUnit::Linear(unit)) - } - None => Some(CircuitUnit::Wire(wire)), - } - }) - .collect(); - // The units in `self.wire_units` have been updated. - // Now we can early return if the node should be ignored. - let tag = optype.tag(); + // Inputs and outputs are also ignored. + // The input wire ids are already set in the `wire_unit` map during initialization. + let tag = self.circ.get_optype(node).tag(); if tag == OpTag::Input || tag == OpTag::Output { return None; } - let mut outputs: Vec<_> = sig - .output_ports() - .map(|port| { - let wire = Wire::new(node, port); - match self.wire_unit.get(&wire) { - Some(&unit) => CircuitUnit::Linear(unit), - None => CircuitUnit::Wire(wire), - } - }) - .collect(); - if let OpType::Const(_) = optype { - // add the static output port from a const. - outputs.push(CircuitUnit::Wire(Wire::new( - node, - optype.other_port_index(Direction::Outgoing).unwrap(), - ))) - } - - todo!() + // Collect the linear units passing through this command into the map + // required to construct a `Command`. + // + // Updates the map tracking the last wire of linear units. + let linear_units: Vec<_> = + Units::new(self.circ, node, Direction::Outgoing, UnitType::Linear, ()) + .map(|(unit, port, _)| { + let CircuitUnit::Linear(_) = unit else { + panic!("Expected a linear unit"); + }; + // Find the linear unit id for this port. + let linear_id = self + .follow_linear_port(node, port) + .and_then(|input_port| self.circ.linked_ports(node, input_port).next()) + .and_then(|(from, from_port)| { + // Remove the old wire from the map (if there was one) + self.wire_unit.remove(&Wire::new(from, from_port)) + }) + .unwrap_or({ + // New linear unit found. Assign it a new id. + self.wire_unit.len() + }); + // Update the map tracking the linear units + let new_wire = Wire::new(node, port); + self.wire_unit.insert(new_wire, linear_id); + linear_id + }) + .collect(); + + Some(linear_units) } /// Returns the linear port on the node that corresponds to the same linear unit. /// /// We assume the linear data uses the same port offsets on both sides of the node. /// In the future we may want to have a more general mechanism to handle this. + // + // Note that `Command::linear_units` assumes this behaviour. fn follow_linear_port(&self, node: Node, port: Port) -> Option { let optype = self.circ.get_optype(node); if !optype.port_kind(port)?.is_linear() { @@ -293,24 +278,33 @@ where impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {} +impl<'circ, Circ: Circuit> std::fmt::Debug for CommandIterator<'circ, Circ> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommandIterator") + .field("circuit name", &self.circ.name()) + .field("wire_unit", &self.wire_unit) + .field("remaining", &self.remaining) + .finish() + } +} + #[cfg(test)] mod test { use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::ops::OpName; use hugr::HugrView; use itertools::Itertools; - use std::fmt::Debug; use crate::utils::build_simple_circuit; use crate::T2Op; use super::*; - fn assert_eq_iter(x: impl Iterator, expected: impl IntoIterator) - where - T: PartialEq + Debug, - { - assert_eq!(x.collect_vec(), expected.into_iter().collect_vec()); + // We use a macro instead of a function to get the failing line numbers right. + macro_rules! assert_eq_iter { + ($iterable:expr, $expected:expr $(,)?) => { + assert_eq!($iterable.collect_vec(), $expected.into_iter().collect_vec()); + }; } #[test] @@ -330,36 +324,37 @@ mod test { let t2op_name = |op: T2Op| >::into(op).name(); let mut commands = CommandIterator::new(&circ); + assert_eq!(commands.size_hint(), (3, Some(3))); let hadamard = commands.next().unwrap(); assert_eq!(hadamard.optype().name().as_str(), t2op_name(T2Op::H)); - assert_eq_iter( + assert_eq_iter!( hadamard.inputs().map(|(u, _, _)| u), [CircuitUnit::Linear(0)], ); - assert_eq_iter( + assert_eq_iter!( hadamard.outputs().map(|(u, _, _)| u), [CircuitUnit::Linear(0)], ); let cx = commands.next().unwrap(); assert_eq!(cx.optype().name().as_str(), t2op_name(T2Op::CX)); - assert_eq_iter( + assert_eq_iter!( cx.inputs().map(|(unit, _, _)| unit), [CircuitUnit::Linear(0), CircuitUnit::Linear(1)], ); - assert_eq_iter( + assert_eq_iter!( cx.outputs().map(|(unit, _, _)| unit), [CircuitUnit::Linear(0), CircuitUnit::Linear(1)], ); let t = commands.next().unwrap(); assert_eq!(t.optype().name().as_str(), t2op_name(T2Op::T)); - assert_eq_iter( + assert_eq_iter!( t.inputs().map(|(unit, _, _)| unit), [CircuitUnit::Linear(1)], ); - assert_eq_iter( + assert_eq_iter!( t.outputs().map(|(unit, _, _)| unit), [CircuitUnit::Linear(1)], ); diff --git a/src/circuit/units.rs b/src/circuit/units.rs index fc1949a9..6d15a08a 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -5,7 +5,7 @@ use std::iter::FusedIterator; use hugr::extension::prelude; use hugr::hugr::CircuitUnit; use hugr::ops::OpTrait; -use hugr::types::{Type, TypeBound, TypeRow}; +use hugr::types::{EdgeKind, Type, TypeBound, TypeRow}; use hugr::{Direction, Node, Port, Wire}; use super::Circuit; @@ -16,7 +16,7 @@ pub type LinearUnit = usize; /// An iterator over the units in the input or output boundary of a [Node]. #[derive(Clone, Debug)] -pub struct Units { +pub struct Units
    { /// Filter over the yielded units. /// /// It can be set to ignore non-linear units, only yield qubits, between @@ -27,10 +27,7 @@ pub struct Units { /// The direction of the boundary. direction: Direction, /// The types of the boundary. - // - // TODO: We could avoid cloning the TypeRow if `OpType::signature` returned - // a reference. - type_row: TypeRow, + types: TypeRow, /// The current index in the type row. current: usize, /// The amount of linear units yielded. @@ -39,7 +36,7 @@ pub struct Units { /// [`CircuitUnit::Linear`] ids. /// /// The default type is `()`, which assigns new linear ids sequentially. - unit_assigner: LA, + unit_assigner: UL, } impl Units<()> { @@ -53,7 +50,7 @@ impl Units<()> { mode: output_mode, node: circuit.input(), direction: Direction::Outgoing, - type_row: circuit + types: circuit .get_function_type() .map_or_else(Default::default, |ft| ft.input.clone()), current: 0, @@ -63,9 +60,9 @@ impl Units<()> { } } -impl Units +impl
      Units
        where - LA: LinearUnitAssigner, + UL: UnitLabeller, { /// Create a new iterator over the units of a node. // @@ -77,73 +74,100 @@ where node: Node, direction: Direction, output_mode: UnitType, - unit_assigner: LA, + unit_assigner: UL, ) -> Self { - let sig = circuit.get_optype(node).signature(); - let type_row = match direction { - Direction::Outgoing => sig.output, - Direction::Incoming => sig.input, - }; Self { mode: output_mode, node, direction, - type_row, + types: Self::init_types(circuit, node, direction), current: 0, linear_count: 0, unit_assigner, } } + /// Initialize the boudary types. + /// + /// We use a [`TypeRow`] to avoid allocating for simple boundaries, but if + /// any static port is present we create a new owned [`TypeRow`] with them included. + // + // TODO: This is quite hacky, but we need it to accept Const static inputs. + // We should revisit it once this is reworked on the HUGR side. + fn init_types(circuit: &impl Circuit, node: Node, direction: Direction) -> TypeRow { + let optype = circuit.get_optype(node); + let sig = optype.signature(); + let mut types = match direction { + Direction::Outgoing => sig.output, + Direction::Incoming => sig.input, + }; + if let Some(other) = optype.static_input() { + if direction == Direction::Incoming { + types.to_mut().push(other); + } + } + if let Some(EdgeKind::Static(other)) = optype.other_port(direction) { + types.to_mut().push(other); + } + types + } + /// Construct an output value to yield. /// - /// Calls [`LinearUnitAssigner::assign`] to assign a linear unit id to the linear ports. - /// Non-linear ports are assigned [`CircuitUnit::Wire`]s. + /// Calls [`UnitLabeller::assign_linear`] to assign a linear unit id to the linear ports. + /// Non-linear ports are assigned [`CircuitUnit::Wire`]s via [`UnitLabeller::assign_wire`]. #[inline] - fn make_value(&self, typ: &Type, port: Port) -> (CircuitUnit, Port, Type) { + fn make_value(&self, typ: &Type, port: Port) -> Option<(CircuitUnit, Port, Type)> { let unit = if type_is_linear(typ) { - let linear_unit = self.unit_assigner.assign(port, self.linear_count - 1); + let linear_unit = + self.unit_assigner + .assign_linear(self.node, port, self.linear_count - 1); CircuitUnit::Linear(linear_unit) } else { - match self.direction { - Direction::Outgoing => CircuitUnit::Wire(Wire::new(self.node, port)), - Direction::Incoming => CircuitUnit::Wire(Wire::new(self.node, port)), - } + let wire = self.unit_assigner.assign_wire(self.node, port)?; + CircuitUnit::Wire(wire) }; - (unit, port, typ.clone()) + Some((unit, port, typ.clone())) } } -impl Iterator for Units +impl
          Iterator for Units
            where - LA: LinearUnitAssigner, + UL: UnitLabeller, { type Item = (CircuitUnit, Port, Type); fn next(&mut self) -> Option { loop { - let typ = self.type_row.get(self.current)?; + let typ = self.types.get(self.current)?; let port = Port::new(self.direction, self.current); self.current += 1; if type_is_linear(typ) { self.linear_count += 1; } if self.mode.accept(typ) { - return Some(self.make_value(typ, port)); + let val = self.make_value(typ, port); + if val.is_some() { + return val; + } } } } fn size_hint(&self) -> (usize, Option) { - let len = self.type_row.len() - self.current; - match self.mode { - UnitType::All => (len, Some(len)), - _ => (0, Some(len)), + let len = self.types.len() - self.current; + if self.mode == UnitType::All && self.direction == Direction::Outgoing { + (len, Some(len)) + } else { + // Even when yielding every unit, a disconnected input non-linear + // port cannot be assigned a `CircuitUnit::Wire` and so it will be + // skipped. + (0, Some(len)) } } } -impl FusedIterator for Units where LA: LinearUnitAssigner {} +impl
              FusedIterator for Units
                where UL: UnitLabeller {} /// What kind of units to iterate over. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] @@ -171,20 +195,40 @@ impl UnitType { } } -/// A map for assigning linear unit ids to ports. -pub trait LinearUnitAssigner { - /// Assign a linear unit id to an output port. +/// An trait for assigning linear unit ids and wires to ports of a node. +pub trait UnitLabeller { + /// Assign a linear unit id to an port. /// /// # Parameters + /// - node: The node in the circuit. /// - port: The node's port in the node. /// - linear_count: The number of linear units yielded so far. - fn assign(&self, port: Port, linear_count: usize) -> LinearUnit; + fn assign_linear(&self, node: Node, port: Port, linear_count: usize) -> LinearUnit; + + /// Assign a wire to a port, if possible. + /// + /// # Parameters + /// - node: The node in the circuit. + /// - port: The node's port in the node. + fn assign_wire(&self, node: Node, port: Port) -> Option; } -impl LinearUnitAssigner for () { - fn assign(&self, _port: Port, linear_count: usize) -> LinearUnit { +/// The default [`UnitLabeller`] that assigns new linear unit ids +/// sequentially, and only assigns wires to an outgoing ports (as input ports +/// require querying the HUGR for their neighbours). +impl UnitLabeller for () { + #[inline] + fn assign_linear(&self, _: Node, _: Port, linear_count: usize) -> LinearUnit { linear_count } + + #[inline] + fn assign_wire(&self, node: Node, port: Port) -> Option { + match port.direction() { + Direction::Incoming => None, + Direction::Outgoing => Some(Wire::new(node, port)), + } + } } fn type_is_linear(typ: &Type) -> bool { diff --git a/src/json/encoder.rs b/src/json/encoder.rs index 0d0a43de..ca7fe187 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -158,6 +158,10 @@ impl JsonEncoder { }) .collect_vec(); if inputs.len() != command.input_count() { + debug_assert!(!matches!( + optype, + OpType::Const(_) | OpType::LoadConstant(_) + )); return false; } From 4afbae0a18ec99c5eef75b7834276faa415c8a14 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 12 Sep 2023 00:35:35 +0100 Subject: [PATCH 10/12] Expand the API (more cmd iterating options) --- src/circuit/command.rs | 73 ++++++++++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/src/circuit/command.rs b/src/circuit/command.rs index f1e1d2e0..e00adb3a 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -7,6 +7,7 @@ use std::iter::FusedIterator; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; +use itertools::Itertools; use petgraph::visit as pv; use super::units::{LinearUnit, UnitLabeller, UnitType, Units}; @@ -33,55 +34,93 @@ pub struct Command<'circ, Circ> { impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the node corresponding to this command. + #[inline] pub fn node(&self) -> Node { self.node } /// Returns the [`NodeType`] of the command. + #[inline] pub fn nodetype(&self) -> &NodeType { self.circ.get_nodetype(self.node) } /// Returns the [`OpType`] of the command. + #[inline] pub fn optype(&self) -> &OpType { self.circ.get_optype(self.node) } - /// Returns the output units of this command. + /// Returns the units of this command in a given direction. + #[inline] + pub fn units(&self, direction: Direction) -> Units<&'_ Self> { + Units::new(self.circ, self.node, direction, UnitType::All, self) + } + + /// Returns the linear units of this command in a given direction. + #[inline] + pub fn linear_units(&self, direction: Direction) -> Units<&'_ Self> { + Units::new(self.circ, self.node, direction, UnitType::Linear, self) + } + + /// Returns the units and wires of this command in a given direction. + #[inline] + pub fn unit_wires( + &self, + direction: Direction, + ) -> impl IntoIterator + '_ { + self.units(direction) + .filter_map(|(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?))) + .collect_vec() + } + + /// Returns the output units of this command. See [`Command::units`]. + #[inline] pub fn outputs(&self) -> Units<&'_ Self> { - Units::new( - self.circ, - self.node, - Direction::Outgoing, - UnitType::All, - self, - ) + self.units(Direction::Outgoing) + } + + /// Returns the linear output units of this command. See [`Command::linear_units`]. + #[inline] + pub fn linear_outputs(&self) -> Units<&'_ Self> { + self.linear_units(Direction::Outgoing) } - /// Returns the output wires of this command. - pub fn output_wires(&self) -> impl FusedIterator + '_ { + /// Returns the output units and wires of this command. See [`Command::unit_wires`]. + #[inline] + pub fn output_wires(&self) -> impl IntoIterator + '_ { + // Specialized version of `self.unit_wires()` that avoids collecting a + // `Vec` since it doesn't need to borrow `self`. self.outputs() .map(|(unit, port, _)| (unit, Wire::new(self.node, port))) } /// Returns the output units of this command. pub fn inputs(&self) -> Units<&'_ Self> { - Units::new( - self.circ, - self.node, - Direction::Incoming, - UnitType::All, - self, - ) + self.units(Direction::Incoming) + } + + /// Returns the linear input units of this command. See [`Command::linear_units`]. + #[inline] + pub fn linear_inputs(&self) -> Units<&'_ Self> { + self.linear_units(Direction::Incoming) + } + + /// Returns the input units and wires of this command. See [`Command::unit_wires`]. + #[inline] + pub fn input_wires(&self) -> impl IntoIterator + '_ { + self.unit_wires(Direction::Incoming) } /// Returns the number of inputs of this command. + #[inline] pub fn input_count(&self) -> usize { let optype = self.optype(); optype.signature().input_count() + optype.static_input().is_some() as usize } /// Returns the number of outputs of this command. + #[inline] pub fn output_count(&self) -> usize { self.optype().signature().output_count() } From ec3becd07972f99239bc276bf2ffb9bbb1ec750f Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 13 Sep 2023 11:11:31 +0100 Subject: [PATCH 11/12] feat: Support arbitrary nodes and directions in `Units` --- src/circuit.rs | 20 ++-- src/circuit/command.rs | 2 +- src/circuit/units.rs | 210 +++++++++++++++++++++++++++++++++-------- src/json/encoder.rs | 2 +- 4 files changed, 181 insertions(+), 53 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index d5c46ec3..076123ea 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -75,38 +75,38 @@ pub trait Circuit: HugrView { /// Get the input units of the circuit and their types. #[inline] - fn units(&self) -> Units<'_> + fn units(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::All) + Units::new_circ_input(self, UnitType::All) } /// Get the linear input units of the circuit and their types. #[inline] - fn linear_units(&self) -> Units<'_> + fn linear_units(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::Linear) + Units::new_circ_input(self, UnitType::Linear) } /// Get the non-linear input units of the circuit and their types. #[inline] - fn nonlinear_units(&self) -> Units<'_> + fn nonlinear_units(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::NonLinear) + Units::new_circ_input(self, UnitType::NonLinear) } /// Returns the units corresponding to qubits inputs to the circuit. #[inline] - fn qubits(&self) -> Units<'_> + fn qubits(&self) -> Units where Self: Sized, { - Units::new(self, UnitType::Qubits) + Units::new_circ_input(self, UnitType::Qubits) } /// Returns all the commands in the circuit, in some topological order. @@ -172,7 +172,7 @@ mod tests { assert_eq!(circ.linear_units().count(), 3); assert_eq!(circ.qubits().count(), 2); - assert!(circ.linear_units().all(|(unit, _)| unit.is_linear())); - assert!(circ.nonlinear_units().all(|(unit, _)| unit.is_wire())); + assert!(circ.linear_units().all(|(unit, _, _)| unit.is_linear())); + assert!(circ.nonlinear_units().all(|(unit, _, _)| unit.is_wire())); } } diff --git a/src/circuit/command.rs b/src/circuit/command.rs index e41e21fa..2be01157 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -68,7 +68,7 @@ where .map(|port| Wire::new(circ.input(), port)); let wire_unit = input_node_wires .zip(circ.linear_units()) - .filter_map(|(wire, (unit, _))| match unit { + .filter_map(|(wire, (unit, _, _))| match unit { CircuitUnit::Linear(i) => Some((wire, i)), _ => None, }) diff --git a/src/circuit/units.rs b/src/circuit/units.rs index bc5d3c30..804f5ffb 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -4,85 +4,177 @@ use std::iter::FusedIterator; use hugr::extension::prelude; use hugr::hugr::CircuitUnit; -use hugr::types::{Type, TypeBound, TypeRow}; -use hugr::{Node, Port, Wire}; +use hugr::ops::OpTrait; +use hugr::types::{EdgeKind, Type, TypeBound, TypeRow}; +use hugr::{Direction, Node, Port, Wire}; use super::Circuit; -/// An iterator over the units of a circuit. -pub struct Units<'a> { - /// Whether to only - output_mode: UnitType, - /// The inputs to the circuit - inputs: Option<&'a TypeRow>, - /// Input node of the circuit - input_node: Node, - /// The current index in the inputs +/// A linear unit id, used in [`CircuitUnit::Linear`]. +// TODO: Add this to hugr? +pub type LinearUnit = usize; + +/// An iterator over the units in the input or output boundary of a [Node]. +#[derive(Clone, Debug)] +pub struct Units
                  { + /// Filter over the yielded units. + /// + /// It can be set to ignore non-linear units, only yield qubits, between + /// other options. See [`UnitType`] for more information. + mode: UnitType, + /// The node of the circuit. + node: Node, + /// The direction of the boundary. + direction: Direction, + /// The types of the boundary. + types: TypeRow, + /// The current index in the type row. current: usize, /// The amount of linear units yielded. linear_count: usize, + /// A pre-set assignment of units that maps linear ports to + /// [`CircuitUnit::Linear`] ids. + /// + /// The default type is `()`, which assigns new linear ids sequentially. + unit_assigner: UL, } -impl<'a> Units<'a> { - /// Create a new iterator over the units of a circuit. - pub(super) fn new(circuit: &'a impl Circuit, output_mode: UnitType) -> Self { +impl Units<()> { + /// Create a new iterator over the input units of a circuit. + /// + /// This iterator will yield all units originating from the circuit's input + /// node. + #[inline] + pub(super) fn new_circ_input(circuit: &impl Circuit, output_mode: UnitType) -> Self { Self { - output_mode, - inputs: circuit.get_function_type().map(|ft| &ft.input), - input_node: circuit.input(), + mode: output_mode, + node: circuit.input(), + direction: Direction::Outgoing, + types: circuit + .get_function_type() + .map_or_else(Default::default, |ft| ft.input.clone()), current: 0, linear_count: 0, + unit_assigner: (), } } +} - /// Construct an output value to yield. - fn make_value(&self, typ: &Type, input_port: Port) -> (CircuitUnit, Type) { - match type_is_linear(typ) { - true => (CircuitUnit::Linear(self.linear_count - 1), typ.clone()), - false => ( - CircuitUnit::Wire(Wire::new(self.input_node, input_port)), - typ.clone(), - ), +impl
                    Units
                      +where + UL: UnitLabeller, +{ + /// Create a new iterator over the units of a node. + // + // Note that this ignores any incoming linear unit labels, and just assigns + // new unit ids sequentially. + #[inline] + #[allow(unused)] + pub(super) fn new( + circuit: &impl Circuit, + node: Node, + direction: Direction, + output_mode: UnitType, + unit_assigner: UL, + ) -> Self { + Self { + mode: output_mode, + node, + direction, + types: Self::init_types(circuit, node, direction), + current: 0, + linear_count: 0, + unit_assigner, } } + + /// Initialize the boundary types. + /// + /// We use a [`TypeRow`] to avoid allocating for simple boundaries, but if + /// any static port is present we create a new owned [`TypeRow`] with them included. + // + // TODO: This is quite hacky, but we need it to accept Const static inputs. + // We should revisit it once this is reworked on the HUGR side. + fn init_types(circuit: &impl Circuit, node: Node, direction: Direction) -> TypeRow { + let optype = circuit.get_optype(node); + let sig = optype.signature(); + let mut types = match direction { + Direction::Outgoing => sig.output, + Direction::Incoming => sig.input, + }; + if let Some(other) = optype.static_input() { + if direction == Direction::Incoming { + types.to_mut().push(other); + } + } + if let Some(EdgeKind::Static(other)) = optype.other_port(direction) { + types.to_mut().push(other); + } + types + } + + /// Construct an output value to yield. + /// + /// Calls [`UnitLabeller::assign_linear`] to assign a linear unit id to the linear ports. + /// Non-linear ports are assigned [`CircuitUnit::Wire`]s via [`UnitLabeller::assign_wire`]. + #[inline] + fn make_value(&self, typ: &Type, port: Port) -> Option<(CircuitUnit, Port, Type)> { + let unit = if type_is_linear(typ) { + let linear_unit = + self.unit_assigner + .assign_linear(self.node, port, self.linear_count - 1); + CircuitUnit::Linear(linear_unit) + } else { + let wire = self.unit_assigner.assign_wire(self.node, port)?; + CircuitUnit::Wire(wire) + }; + Some((unit, port, typ.clone())) + } } -impl<'a> Iterator for Units<'a> { - type Item = (CircuitUnit, Type); +impl
                        Iterator for Units
                          +where + UL: UnitLabeller, +{ + type Item = (CircuitUnit, Port, Type); fn next(&mut self) -> Option { - let inputs = self.inputs?; loop { - let typ = inputs.get(self.current)?; - let input_port = Port::new_outgoing(self.current); + let typ = self.types.get(self.current)?; + let port = Port::new(self.direction, self.current); self.current += 1; if type_is_linear(typ) { self.linear_count += 1; } - if self.output_mode.accept(typ) { - return Some(self.make_value(typ, input_port)); + if self.mode.accept(typ) { + let val = self.make_value(typ, port); + if val.is_some() { + return val; + } } } } fn size_hint(&self) -> (usize, Option) { - let len = self - .inputs - .map(|inputs| inputs.len() - self.current) - .unwrap_or(0); - match self.output_mode { - UnitType::All => (len, Some(len)), - _ => (0, Some(len)), + let len = self.types.len() - self.current; + if self.mode == UnitType::All && self.direction == Direction::Outgoing { + (len, Some(len)) + } else { + // Even when yielding every unit, a disconnected input non-linear + // port cannot be assigned a `CircuitUnit::Wire` and so it will be + // skipped. + (0, Some(len)) } } } -impl<'a> FusedIterator for Units<'a> {} +impl
                            FusedIterator for Units
                              where UL: UnitLabeller {} /// What kind of units to iterate over. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub(super) enum UnitType { /// All units. + #[default] All, /// Only the linear units. Linear, @@ -104,6 +196,42 @@ impl UnitType { } } +/// An trait for assigning linear unit ids and wires to ports of a node. +pub trait UnitLabeller { + /// Assign a linear unit id to an port. + /// + /// # Parameters + /// - node: The node in the circuit. + /// - port: The node's port in the node. + /// - linear_count: The number of linear units yielded so far. + fn assign_linear(&self, node: Node, port: Port, linear_count: usize) -> LinearUnit; + + /// Assign a wire to a port, if possible. + /// + /// # Parameters + /// - node: The node in the circuit. + /// - port: The node's port in the node. + fn assign_wire(&self, node: Node, port: Port) -> Option; +} + +/// The default [`UnitLabeller`] that assigns new linear unit ids +/// sequentially, and only assigns wires to an outgoing ports (as input ports +/// require querying the HUGR for their neighbours). +impl UnitLabeller for () { + #[inline] + fn assign_linear(&self, _: Node, _: Port, linear_count: usize) -> LinearUnit { + linear_count + } + + #[inline] + fn assign_wire(&self, node: Node, port: Port) -> Option { + match port.direction() { + Direction::Incoming => None, + Direction::Outgoing => Some(Wire::new(node, port)), + } + } +} + fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } diff --git a/src/json/encoder.rs b/src/json/encoder.rs index e087d147..ed11985e 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -49,7 +49,7 @@ impl JsonEncoder { // TODO Throw an error on non-recognized unit types, or just ignore? let mut bit_units = HashMap::new(); let mut qubit_units = HashMap::new(); - for (unit, ty) in circ.units() { + for (unit, _, ty) in circ.units() { if ty == QB_T { let index = vec![qubit_units.len() as i64]; let reg = circuit_json::Register("q".to_string(), index); From c01473a383af01fb06501ca4c88bd6d6b6d92a47 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 20 Sep 2023 14:08:24 +0100 Subject: [PATCH 12/12] Apply review comments --- src/circuit/command.rs | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 2a338c46..ac8b68c2 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -7,9 +7,9 @@ use std::iter::FusedIterator; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; -use itertools::Itertools; use petgraph::visit as pv; +use super::units::filter::FilteredUnits; use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units}; use super::Circuit; @@ -59,8 +59,8 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the linear units of this command in a given direction. #[inline] - pub fn linear_units(&self, direction: Direction) -> Units<&'_ Self> { - Units::new(self.circ, self.node, direction, self) + pub fn linear_units(&self, direction: Direction) -> FilteredUnits { + Units::new(self.circ, self.node, direction, self).filter_units::() } /// Returns the units and wires of this command in a given direction. @@ -70,8 +70,7 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { direction: Direction, ) -> impl IntoIterator + '_ { self.units(direction) - .filter_map(|(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?))) - .collect_vec() + .filter_map(move |(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?))) } /// Returns the output units of this command. See [`Command::units`]. @@ -82,27 +81,25 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the linear output units of this command. See [`Command::linear_units`]. #[inline] - pub fn linear_outputs(&self) -> Units<&'_ Self> { + pub fn linear_outputs(&self) -> FilteredUnits { self.linear_units(Direction::Outgoing) } /// Returns the output units and wires of this command. See [`Command::unit_wires`]. #[inline] pub fn output_wires(&self) -> impl IntoIterator + '_ { - // Specialized version of `self.unit_wires()` that avoids collecting a - // `Vec` since it doesn't need to borrow `self`. - self.outputs() - .map(|(unit, port, _)| (unit, Wire::new(self.node, port))) + self.unit_wires(Direction::Outgoing) } /// Returns the output units of this command. + #[inline] pub fn inputs(&self) -> Units<&'_ Self> { self.units(Direction::Incoming) } /// Returns the linear input units of this command. See [`Command::linear_units`]. #[inline] - pub fn linear_inputs(&self) -> Units<&'_ Self> { + pub fn linear_inputs(&self) -> FilteredUnits { self.linear_units(Direction::Incoming) } @@ -199,12 +196,11 @@ where pub(super) fn new(circ: &'circ Circ) -> Self { // Initialize the map assigning linear units to the input's linear // ports. + // + // TODO: `with_wires` combinator for `Units`? let wire_unit = circ - .units() - .filter_map(|(unit, port, _)| match unit { - CircuitUnit::Linear(i) => Some((Wire::new(circ.input(), port), i)), - _ => None, - }) + .linear_units() + .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit)) .collect(); let nodes = pv::Topo::new(&circ.as_petgraph());