diff --git a/src/circuit/command.rs b/src/circuit/command.rs
index 9bad28ac..b1cb390e 100644
--- a/src/circuit/command.rs
+++ b/src/circuit/command.rs
@@ -2,6 +2,7 @@
 //!
 //! A [`Command`] is an operation applied to an specific wires, possibly identified by their index in the circuit's input vector.
 
+use std::collections::hash_map::Entry;
 use std::collections::{HashMap, HashSet};
 use std::iter::FusedIterator;
 
@@ -242,6 +243,24 @@ pub struct CommandIterator<'circ, Circ> {
     wire_unit: HashMap,
     /// Remaining commands, not counting I/O nodes.
     remaining: usize,
+    /// Delayed output of constant and load const nodes. Contains nodes that
+    /// haven't been yielded yet.
+    ///
+    /// We only yield them as Commands when their consumers require them.
+    delayed_consts: HashSet<Node>,
+    /// Nodes with delayed predecessors.
+    ///
+    /// Each node is associated with the number of predecessors that are present
+    /// in `delayed_consts`.
+    ///
+    /// This map is used for performance, to avoid checking the neighbours vs
+    /// the `delayed_consts` set for each processed node.
+    delayed_consumers: HashMap<Node, usize>,
+    /// The next node to be processed.
+    ///
+    /// This node was produced by the last call to `nodes.next()`, but we had to
+    /// yield some delayed const nodes before it.
+    delayed_node: Option<Node>,
 }
 
 impl<'circ, Circ> CommandIterator<'circ, Circ>
@@ -266,6 +285,92 @@ where
             wire_unit,
             // Ignore the input and output nodes, and the root.
             remaining: circ.node_count() - 3,
+            delayed_consts: HashSet::new(),
+            delayed_consumers: HashMap::new(),
+            delayed_node: None,
+        }
+    }
+
+    /// Returns the next node to be processed.
+    ///
+    /// Constant and load const nodes met in the topological order are not
+    /// returned right away: they are stored in `delayed_consts` and returned
+    /// later, ahead of the first non-constant node that depends on them.
+    ///
+    /// NOTE(review): a delayed node is only ever returned on behalf of one of
+    /// its consumers, so a constant without output neighbours is never
+    /// returned from here. Confirm that `remaining` accounts for that.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `delayed_consts` and `delayed_consumers` are out of sync.
+    fn next_node(&mut self) -> Option<Node> {
+        // Pull nodes until we reach one that is not a constant. This is a loop
+        // rather than a recursive call so that a long run of constants cannot
+        // exhaust the stack.
+        let node = loop {
+            let node = self
+                .delayed_node
+                .take()
+                .or_else(|| self.nodes.next(&self.circ.as_petgraph()))?;
+
+            let tag = self.circ.get_optype(node).tag();
+            if tag != OpTag::Const && tag != OpTag::LoadConst {
+                break node;
+            }
+
+            // Delay the constant, and count it against each of its consumers.
+            self.delayed_consts.insert(node);
+            for consumer in self.circ.output_neighbours(node) {
+                *self.delayed_consumers.entry(consumer).or_default() += 1;
+            }
+        };
+
+        // Nothing delayed feeds into this node, so it can be returned as is.
+        if !self.delayed_consumers.contains_key(&node) {
+            return Some(node);
+        }
+
+        // Return one delayed predecessor first, and release its consumers.
+        let delayed = self.next_delayed_node(node);
+        self.delayed_consts.remove(&delayed);
+        for consumer in self.circ.output_neighbours(delayed) {
+            let Entry::Occupied(mut entry) = self.delayed_consumers.entry(consumer) else {
+                panic!("Delayed node consumer was not in delayed_consumers. Delayed node: {delayed:?}, consumer: {consumer:?}.");
+            };
+            *entry.get_mut() -= 1;
+            if *entry.get() == 0 {
+                entry.remove();
+            }
+        }
+        // Look at `node` again on the next call; it may have more delayed
+        // predecessors left.
+        self.delayed_node = Some(node);
+        Some(delayed)
+    }
+
+    /// Given a node with delayed predecessors, returns one of those predecessors.
+    ///
+    /// The returned node has no delayed predecessors of its own.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `consumer` has no input neighbour in `delayed_consts`.
+    fn next_delayed_node(&mut self, mut consumer: Node) -> Node {
+        loop {
+            let Some(delayed_pred) = self
+                .circ
+                .input_neighbours(consumer)
+                .find(|k| self.delayed_consts.contains(k))
+            else {
+                panic!("Could not find a delayed predecessor for node {consumer:?}.");
+            };
+
+            // Only output this node if it doesn't require any other delayed predecessors.
+            if !self.delayed_consumers.contains_key(&delayed_pred) {
+                return delayed_pred;
+            }
+            consumer = delayed_pred;
         }
     }
 
@@ -347,7 +452,8 @@ where
     #[inline]
     fn next(&mut self) -> Option {
         loop {
-            let node = self.nodes.next(&self.circ.as_petgraph())?;
+            let node = self.next_node()?;
+            // Process the node, returning a command if it's not an input or output.
             if let Some(linear_units) = self.process_node(node) {
                 self.remaining -= 1;
                 return Some(Command {