diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs index ed49a0d8..42c4a757 100644 --- a/compile-matcher/src/main.rs +++ b/compile-matcher/src/main.rs @@ -2,9 +2,6 @@ use std::fs; use std::path::Path; use clap::Parser; -use hugr::hugr::views::{HierarchyView, SiblingGraph}; -use hugr::ops::handle::DfgID; -use hugr::HugrView; use itertools::Itertools; use tket2::json::load_tk1_json_file; @@ -65,9 +62,8 @@ fn main() { let patterns = all_circs .iter() .filter_map(|circ| { - let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(circ, circ.root()); // Fail silently on empty or disconnected patterns - CircuitPattern::try_from_circuit(&circ).ok() + CircuitPattern::try_from_circuit(circ).ok() }) .collect_vec(); println!("Loaded {} patterns.", patterns.len()); diff --git a/src/circuit.rs b/src/circuit.rs index 030b0ff6..ff3eb654 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -7,7 +7,6 @@ mod units; pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; -use hugr::hugr::NodeType; use hugr::HugrView; pub use hugr::ops::OpType; @@ -119,16 +118,6 @@ pub trait Circuit: HugrView { // Traverse the circuit in topological order. CommandIterator::new(self) } - - /// Returns the [`NodeType`] of a command. - fn command_nodetype(&self, command: &Command) -> &NodeType { - self.get_nodetype(command.node()) - } - - /// Returns the [`OpType`] of a command. - fn command_optype(&self, command: &Command) -> &OpType { - self.get_optype(command.node()) - } } impl Circuit for T where T: HugrView {} diff --git a/src/circuit/command.rs b/src/circuit/command.rs index a9dae464..ac8b68c2 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -2,11 +2,15 @@ //! //! A [`Command`] is an operation applied to an specific wires, possibly identified by their index in the circuit's input vector. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::iter::FusedIterator; +use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; +use petgraph::visit as pv; +use super::units::filter::FilteredUnits; +use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units}; use super::Circuit; pub use hugr::hugr::CircuitUnit; @@ -15,44 +19,173 @@ pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Direction, Node, Port, Wire}; /// An operation applied to specific wires. -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Command { +#[derive(Eq, PartialOrd, Ord, Hash)] +pub struct Command<'circ, Circ> { + /// The circuit. + circ: &'circ Circ, /// The operation node. node: Node, - /// The input units to the operation. - inputs: Vec, - /// The output units to the operation. - outputs: Vec, + /// An assignment of linear units to the node's ports. + // + // We'll need something more complex if `follow_linear_port` stops being a + // direct map from input to output. + linear_units: Vec, } -impl Command { +impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the node corresponding to this command. + #[inline] pub fn node(&self) -> Node { self.node } - /// Returns the output units of this command. - pub fn outputs(&self) -> &Vec { - &self.outputs + /// Returns the [`NodeType`] of the command. + #[inline] + pub fn nodetype(&self) -> &NodeType { + self.circ.get_nodetype(self.node) + } + + /// Returns the [`OpType`] of the command. + #[inline] + pub fn optype(&self) -> &OpType { + self.circ.get_optype(self.node) + } + + /// Returns the units of this command in a given direction. + #[inline] + pub fn units(&self, direction: Direction) -> Units<&'_ Self> { + Units::new(self.circ, self.node, direction, self) + } + + /// Returns the linear units of this command in a given direction. + #[inline] + pub fn linear_units(&self, direction: Direction) -> FilteredUnits { + Units::new(self.circ, self.node, direction, self).filter_units::() + } + + /// Returns the units and wires of this command in a given direction. + #[inline] + pub fn unit_wires( + &self, + direction: Direction, + ) -> impl IntoIterator + '_ { + self.units(direction) + .filter_map(move |(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?))) + } + + /// Returns the output units of this command. See [`Command::units`]. + #[inline] + pub fn outputs(&self) -> Units<&'_ Self> { + self.units(Direction::Outgoing) + } + + /// Returns the linear output units of this command. See [`Command::linear_units`]. + #[inline] + pub fn linear_outputs(&self) -> FilteredUnits { + self.linear_units(Direction::Outgoing) + } + + /// Returns the output units and wires of this command. See [`Command::unit_wires`]. + #[inline] + pub fn output_wires(&self) -> impl IntoIterator + '_ { + self.unit_wires(Direction::Outgoing) } /// Returns the output units of this command. - pub fn inputs(&self) -> &Vec { - &self.inputs + #[inline] + pub fn inputs(&self) -> Units<&'_ Self> { + self.units(Direction::Incoming) + } + + /// Returns the linear input units of this command. See [`Command::linear_units`]. + #[inline] + pub fn linear_inputs(&self) -> FilteredUnits { + self.linear_units(Direction::Incoming) + } + + /// Returns the input units and wires of this command. See [`Command::unit_wires`]. + #[inline] + pub fn input_wires(&self) -> impl IntoIterator + '_ { + self.unit_wires(Direction::Incoming) + } + + /// Returns the number of inputs of this command. + #[inline] + pub fn input_count(&self) -> usize { + let optype = self.optype(); + optype.signature().input_count() + optype.static_input().is_some() as usize + } + + /// Returns the number of outputs of this command. + #[inline] + pub fn output_count(&self) -> usize { + self.optype().signature().output_count() } } +impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { + #[inline] + fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit { + *self.linear_units.get(port.index()).unwrap_or_else(|| { + panic!( + "Could not assign a linear unit to port {port:?} of node {:?}", + self.node + ) + }) + } + + #[inline] + fn assign_wire(&self, node: Node, port: Port) -> Option { + match port.direction() { + Direction::Incoming => { + let (from, from_port) = self.circ.linked_ports(node, port).next()?; + Some(Wire::new(from, from_port)) + } + Direction::Outgoing => Some(Wire::new(node, port)), + } + } +} + +impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Command") + .field("circuit name", &self.circ.name()) + .field("node", &self.node) + .field("linear_units", &self.linear_units) + .finish() + } +} + +impl<'circ, Circ> PartialEq for Command<'circ, Circ> { + fn eq(&self, other: &Self) -> bool { + self.node == other.node && self.linear_units == other.linear_units + } +} + +impl<'circ, Circ> Clone for Command<'circ, Circ> { + fn clone(&self) -> Self { + Self { + circ: self.circ, + node: self.node, + linear_units: self.linear_units.clone(), + } + } +} + +/// A non-borrowing topological walker over the nodes of a circuit. +type NodeWalker = pv::Topo>; + /// An iterator over the commands of a circuit. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CommandIterator<'circ, Circ> { - /// The circuit + /// The circuit. circ: &'circ Circ, - /// Toposorted nodes - nodes: Vec, - /// Current element in `nodes` - current: usize, - /// Last wires for each linear `CircuitUnit` + /// Toposorted nodes. + nodes: NodeWalker, + /// Last wire for each [`LinearUnit`] in the circuit. wire_unit: HashMap, + /// Remaining commands, not counting I/O nodes. + remaining: usize, } impl<'circ, Circ> CommandIterator<'circ, Circ> @@ -61,100 +194,85 @@ where { /// Create a new iterator over the commands of a circuit. pub(super) fn new(circ: &'circ Circ) -> Self { - // Initialize the linear units from the input's linear ports. - // TODO: Clean this up - let input_node_wires = circ - .node_outputs(circ.input()) - .map(|port| Wire::new(circ.input(), port)); - let wire_unit = input_node_wires - .zip(circ.linear_units()) - .map(|(wire, (linear_unit, _, _))| (wire, linear_unit)) + // Initialize the map assigning linear units to the input's linear + // ports. + // + // TODO: `with_wires` combinator for `Units`? + let wire_unit = circ + .linear_units() + .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit)) .collect(); - let nodes = petgraph::algo::toposort(&circ.as_petgraph(), None).unwrap(); + let nodes = pv::Topo::new(&circ.as_petgraph()); Self { circ, nodes, - current: 0, wire_unit, + // Ignore the input and output nodes, and the root. + remaining: circ.node_count() - 3, } } - /// Process a new node, updating wires in `unit_wires` and returns the - /// command for the node if it's not an input or output. - fn process_node(&mut self, node: Node) -> Option { - let optype = self.circ.get_optype(node); - let sig = optype.signature(); - + /// Process a new node, updating wires in `unit_wires`. + /// + /// Returns the an option with the `linear_units` used to construct a + /// [`Command`], if the node is not an input or output. + /// + /// We don't return the command directly to avoid lifetime issues due to the + /// mutable borrow here. + fn process_node(&mut self, node: Node) -> Option> { // The root node is ignored. if node == self.circ.root() { return None; } - - // Get the wire corresponding to each input unit. - // TODO: Add this to HugrView? - let inputs: Vec<_> = sig - .input_ports() - .chain( - // add the static input port - optype - .static_input() - // TODO query optype for this port once it is available in hugr. - .map(|_| Port::new_incoming(sig.input.len())), - ) - .filter_map(|port| { - let (from, from_port) = self.circ.linked_ports(node, port).next()?; - let wire = Wire::new(from, from_port); - // Get the unit corresponding to a wire, or return a wire Unit. - match self.wire_unit.remove(&wire) { - Some(unit) => { - if let Some(new_port) = self.follow_linear_port(node, port) { - self.wire_unit.insert(Wire::new(node, new_port), unit); - } - Some(CircuitUnit::Linear(unit)) - } - None => Some(CircuitUnit::Wire(wire)), - } - }) - .collect(); - // The units in `self.wire_units` have been updated. - // Now we can early return if the node should be ignored. - let tag = optype.tag(); + // Inputs and outputs are also ignored. + // The input wire ids are already set in the `wire_unit` map during initialization. + let tag = self.circ.get_optype(node).tag(); if tag == OpTag::Input || tag == OpTag::Output { return None; } - let mut outputs: Vec<_> = sig - .output_ports() - .map(|port| { - let wire = Wire::new(node, port); - match self.wire_unit.get(&wire) { - Some(&unit) => CircuitUnit::Linear(unit), - None => CircuitUnit::Wire(wire), - } - }) - .collect(); - if let OpType::Const(_) = optype { - // add the static output port from a const. - outputs.push(CircuitUnit::Wire(Wire::new( - node, - optype.other_port_index(Direction::Outgoing).unwrap(), - ))) - } - Some(Command { - node, - inputs, - outputs, - }) + // Collect the linear units passing through this command into the map + // required to construct a `Command`. + // + // Updates the map tracking the last wire of linear units. + let linear_units: Vec<_> = + Units::new(self.circ, node, Direction::Outgoing, DefaultUnitLabeller) + .filter_units::() + .map(|(_, port, _)| { + // Find the linear unit id for this port. + let linear_id = self + .follow_linear_port(node, port) + .and_then(|input_port| self.circ.linked_ports(node, input_port).next()) + .and_then(|(from, from_port)| { + // Remove the old wire from the map (if there was one) + self.wire_unit.remove(&Wire::new(from, from_port)) + }) + .unwrap_or({ + // New linear unit found. Assign it a new id. + self.wire_unit.len() + }); + // Update the map tracking the linear units + let new_wire = Wire::new(node, port); + self.wire_unit.insert(new_wire, linear_id); + linear_id + }) + .collect(); + + Some(linear_units) } + /// Returns the linear port on the node that corresponds to the same linear unit. + /// + /// We assume the linear data uses the same port offsets on both sides of the node. + /// In the future we may want to have a more general mechanism to handle this. + // + // Note that `Command::linear_units` assumes this behaviour. fn follow_linear_port(&self, node: Node, port: Port) -> Option { let optype = self.circ.get_optype(node); if !optype.port_kind(port)?.is_linear() { return None; } - // TODO: We assume the linear data uses the same port offsets on both sides of the node. - // In the future we may want to have a more general mechanism to handle this. let other_port = Port::new(port.direction().reverse(), port.index()); if optype.port_kind(other_port) == optype.port_kind(port) { Some(other_port) @@ -168,41 +286,60 @@ impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> where Circ: Circuit, { - type Item = Command; + type Item = Command<'circ, Circ>; + #[inline] fn next(&mut self) -> Option { loop { - if self.current == self.nodes.len() { - return None; - } - let node = self.nodes[self.current]; - let com = self.process_node(node); - self.current += 1; - if com.is_some() { - return com; + let node = self.nodes.next(&self.circ.as_petgraph())?; + if let Some(linear_units) = self.process_node(node) { + self.remaining -= 1; + return Some(Command { + circ: self.circ, + node, + linear_units, + }); } } } #[inline] fn size_hint(&self) -> (usize, Option) { - (0, Some(self.nodes.len() - self.current)) + (self.remaining, Some(self.remaining)) } } impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {} +impl<'circ, Circ: Circuit> std::fmt::Debug for CommandIterator<'circ, Circ> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommandIterator") + .field("circuit name", &self.circ.name()) + .field("wire_unit", &self.wire_unit) + .field("remaining", &self.remaining) + .finish() + } +} + #[cfg(test)] mod test { use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::ops::OpName; use hugr::HugrView; + use itertools::Itertools; use crate::utils::build_simple_circuit; use crate::T2Op; use super::*; + // We use a macro instead of a function to get the failing line numbers right. + macro_rules! assert_eq_iter { + ($iterable:expr, $expected:expr $(,)?) => { + assert_eq!($iterable.collect_vec(), $expected.into_iter().collect_vec()); + }; + } + #[test] fn iterate_commands() { let hugr = build_simple_circuit(2, |circ| { @@ -220,33 +357,40 @@ mod test { let t2op_name = |op: T2Op| >::into(op).name(); let mut commands = CommandIterator::new(&circ); + assert_eq!(commands.size_hint(), (3, Some(3))); let hadamard = commands.next().unwrap(); - assert_eq!( - circ.command_optype(&hadamard).name().as_str(), - t2op_name(T2Op::H) + assert_eq!(hadamard.optype().name().as_str(), t2op_name(T2Op::H)); + assert_eq_iter!( + hadamard.inputs().map(|(u, _, _)| u), + [CircuitUnit::Linear(0)], + ); + assert_eq_iter!( + hadamard.outputs().map(|(u, _, _)| u), + [CircuitUnit::Linear(0)], ); - assert_eq!(hadamard.inputs(), &[CircuitUnit::Linear(0)]); - assert_eq!(hadamard.outputs(), &[CircuitUnit::Linear(0)]); let cx = commands.next().unwrap(); - assert_eq!( - circ.command_optype(&cx).name().as_str(), - t2op_name(T2Op::CX) - ); - assert_eq!( - cx.inputs(), - &[CircuitUnit::Linear(0), CircuitUnit::Linear(1)] + assert_eq!(cx.optype().name().as_str(), t2op_name(T2Op::CX)); + assert_eq_iter!( + cx.inputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(0), CircuitUnit::Linear(1)], ); - assert_eq!( - cx.outputs(), - &[CircuitUnit::Linear(0), CircuitUnit::Linear(1)] + assert_eq_iter!( + cx.outputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(0), CircuitUnit::Linear(1)], ); let t = commands.next().unwrap(); - assert_eq!(circ.command_optype(&t).name().as_str(), t2op_name(T2Op::T)); - assert_eq!(t.inputs(), &[CircuitUnit::Linear(1)]); - assert_eq!(t.outputs(), &[CircuitUnit::Linear(1)]); + assert_eq!(t.optype().name().as_str(), t2op_name(T2Op::T)); + assert_eq_iter!( + t.inputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(1)], + ); + assert_eq_iter!( + t.outputs().map(|(unit, _, _)| unit), + [CircuitUnit::Linear(1)], + ); assert_eq!(commands.next(), None); } diff --git a/src/circuit/units.rs b/src/circuit/units.rs index ffb96ab2..222705a1 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -81,7 +81,6 @@ where // Note that this ignores any incoming linear unit labels, and just assigns // new unit ids sequentially. #[inline] - #[allow(unused)] pub(super) fn new( circuit: &impl Circuit, node: Node, diff --git a/src/json.rs b/src/json.rs index f197918f..37ddbb59 100644 --- a/src/json.rs +++ b/src/json.rs @@ -67,8 +67,8 @@ impl TKETDecode for SerialCircuit { fn encode(circ: &impl Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); for com in circ.commands() { - let optype = circ.command_optype(&com); - encoder.add_command(com, optype)?; + let optype = com.optype(); + encoder.add_command(com.clone(), optype)?; } Ok(encoder.finish()) } diff --git a/src/json/encoder.rs b/src/json/encoder.rs index ed11985e..ca7fe187 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -88,7 +88,11 @@ impl JsonEncoder { } /// Add a circuit command to the serialization. - pub fn add_command(&mut self, command: Command, optype: &OpType) -> Result<(), OpConvertError> { + pub fn add_command( + &mut self, + command: Command<'_, C>, + optype: &OpType, + ) -> Result<(), OpConvertError> { // Register any output of the command that can be used as a TKET1 parameter. if self.record_parameters(&command, optype) { // for now all ops that record parameters should be ignored (are @@ -99,8 +103,7 @@ impl JsonEncoder { let (args, params): (Vec, Vec) = command .inputs() - .iter() - .partition_map(|&u| match self.unit_to_register(u) { + .partition_map(|(u, _, _)| match self.unit_to_register(u) { Some(r) => Either::Left(r), None => match u { CircuitUnit::Wire(w) => Either::Right(w), @@ -145,17 +148,20 @@ impl JsonEncoder { /// Record any output of the command that can be used as a TKET1 parameter. /// Returns whether parameters were recorded. /// Associates the output wires with the parameter expression. - fn record_parameters(&mut self, command: &Command, optype: &OpType) -> bool { + fn record_parameters(&mut self, command: &Command<'_, C>, optype: &OpType) -> bool { // Only consider commands where all inputs are parameters. let inputs = command .inputs() - .iter() - .filter_map(|unit| match unit { - CircuitUnit::Wire(wire) => self.parameters.get(wire), + .filter_map(|(unit, _, _)| match unit { + CircuitUnit::Wire(wire) => self.parameters.get(&wire), CircuitUnit::Linear(_) => None, }) .collect_vec(); - if inputs.len() != command.inputs().len() { + if inputs.len() != command.input_count() { + debug_assert!(!matches!( + optype, + OpType::Const(_) | OpType::LoadConstant(_) + )); return false; } @@ -194,9 +200,9 @@ impl JsonEncoder { } }; - for unit in command.outputs() { + for (unit, _, _) in command.outputs() { if let CircuitUnit::Wire(wire) = unit { - self.parameters.insert(*wire, param.clone()); + self.parameters.insert(wire, param.clone()); } } true diff --git a/src/ops.rs b/src/ops.rs index 86aa19b4..782fbd55 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -46,6 +46,7 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2" )] #[cfg_attr(feature = "pyo3", pyclass)] #[allow(missing_docs)] +#[non_exhaustive] /// Simple enum of tket 2 quantum operations. pub enum T2Op { H, diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index e54374d0..cf97d739 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -39,9 +39,9 @@ impl CircuitPattern { } let mut pattern = Pattern::new(); for cmd in circuit.commands() { - let op = circuit.command_optype(&cmd).clone(); + let op = cmd.optype().clone(); pattern.require(cmd.node(), op.try_into().unwrap()); - for out_offset in 0..cmd.outputs().len() { + for out_offset in 0..cmd.output_count() { let out_offset = Port::new_outgoing(out_offset); for (next_node, in_offset) in circuit.linked_ports(cmd.node(), out_offset) { if circuit.get_optype(next_node).tag() != hugr::ops::OpTag::Output { diff --git a/taso-optimiser/Cargo.toml b/taso-optimiser/Cargo.toml index eb862e14..787f5654 100644 --- a/taso-optimiser/Cargo.toml +++ b/taso-optimiser/Cargo.toml @@ -11,4 +11,9 @@ serde_json = "1.0" tket2 = { path = "../", features = ["portmatching"] } quantinuum-hugr = { workspace = true } itertools = { workspace = true } -tket-json-rs = { workspace = true } \ No newline at end of file +tket-json-rs = { workspace = true } +peak_alloc = { version = "0.2.0", optional = true } + +[features] +default = ["peak_alloc"] +peak_alloc = ["dep:peak_alloc"] \ No newline at end of file diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index f73bd12d..09b022a1 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -8,6 +8,13 @@ use tket2::{ }; use tket_json_rs::circuit_json::SerialCircuit; +#[cfg(feature = "peak_alloc")] +use peak_alloc::PeakAlloc; + +#[cfg(feature = "peak_alloc")] +#[global_allocator] +static PEAK_ALLOC: PeakAlloc = PeakAlloc; + /// Optimise circuits using Quartz-generated ECCs. /// /// Quartz: @@ -91,5 +98,8 @@ fn main() { println!("Saving result"); save_tk1_json_file(output_path, &opt_circ).unwrap(); + #[cfg(feature = "peak_alloc")] + println!("Peak memory usage: {} GB", PEAK_ALLOC.peak_usage_as_gb()); + println!("Done.") }