diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 2a338c46..ac8b68c2 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -7,9 +7,9 @@ use std::iter::FusedIterator; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; -use itertools::Itertools; use petgraph::visit as pv; +use super::units::filter::FilteredUnits; use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units}; use super::Circuit; @@ -59,8 +59,8 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the linear units of this command in a given direction. #[inline] - pub fn linear_units(&self, direction: Direction) -> Units<&'_ Self> { - Units::new(self.circ, self.node, direction, self) + pub fn linear_units(&self, direction: Direction) -> FilteredUnits { + Units::new(self.circ, self.node, direction, self).filter_units::() } /// Returns the units and wires of this command in a given direction. @@ -70,8 +70,7 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { direction: Direction, ) -> impl IntoIterator + '_ { self.units(direction) - .filter_map(|(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?))) - .collect_vec() + .filter_map(move |(unit, port, _)| Some((unit, self.assign_wire(self.node, port)?))) } /// Returns the output units of this command. See [`Command::units`]. @@ -82,27 +81,25 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the linear output units of this command. See [`Command::linear_units`]. #[inline] - pub fn linear_outputs(&self) -> Units<&'_ Self> { + pub fn linear_outputs(&self) -> FilteredUnits { self.linear_units(Direction::Outgoing) } /// Returns the output units and wires of this command. See [`Command::unit_wires`]. #[inline] pub fn output_wires(&self) -> impl IntoIterator + '_ { - // Specialized version of `self.unit_wires()` that avoids collecting a - // `Vec` since it doesn't need to borrow `self`. - self.outputs() - .map(|(unit, port, _)| (unit, Wire::new(self.node, port))) + self.unit_wires(Direction::Outgoing) } /// Returns the output units of this command. + #[inline] pub fn inputs(&self) -> Units<&'_ Self> { self.units(Direction::Incoming) } /// Returns the linear input units of this command. See [`Command::linear_units`]. #[inline] - pub fn linear_inputs(&self) -> Units<&'_ Self> { + pub fn linear_inputs(&self) -> FilteredUnits { self.linear_units(Direction::Incoming) } @@ -199,12 +196,11 @@ where pub(super) fn new(circ: &'circ Circ) -> Self { // Initialize the map assigning linear units to the input's linear // ports. + // + // TODO: `with_wires` combinator for `Units`? let wire_unit = circ - .units() - .filter_map(|(unit, port, _)| match unit { - CircuitUnit::Linear(i) => Some((Wire::new(circ.input(), port), i)), - _ => None, - }) + .linear_units() + .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit)) .collect(); let nodes = pv::Topo::new(&circ.as_petgraph());