From f42746f798f7000849c4019c0b8e914e2ddb3257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 5 Oct 2023 18:33:53 +0200 Subject: [PATCH] fix: Support empty wires in CircuitChunks (#172) Now TASO optimizations that result in identity wires won't panic when reassembling the split chunks. --- src/passes/chunks.rs | 268 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 223 insertions(+), 45 deletions(-) diff --git a/src/passes/chunks.rs b/src/passes/chunks.rs index 768039b1..a35c50b4 100644 --- a/src/passes/chunks.rs +++ b/src/passes/chunks.rs @@ -6,12 +6,13 @@ use std::collections::HashMap; use std::mem; use std::ops::{Index, IndexMut}; +use derive_more::From; use hugr::builder::{Container, FunctionBuilder}; use hugr::extension::ExtensionSet; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::views::sibling_subgraph::ConvexChecker; use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph}; -use hugr::hugr::{HugrError, NodeMetadata}; +use hugr::hugr::{HugrError, NodeMetadata, PortIndex}; use hugr::ops::handle::DataflowParentID; use hugr::ops::OpType; use hugr::types::{FunctionType, Signature}; @@ -34,7 +35,8 @@ use tket_json_rs::circuit_json::SerialCircuit; /// /// When reassembling the circuit, the input/output wires of each chunk are /// re-linked by matching these identifiers. -pub type ChunkConnection = Wire; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)] +pub struct ChunkConnection(Wire); /// A chunk of a circuit. #[derive(Debug, Clone)] @@ -81,13 +83,13 @@ impl Chunk { .exactly_one() .ok() .unwrap(); - Wire::new(out_node, out_port) + Wire::new(out_node, out_port).into() }) .collect(); let outputs = subgraph .outgoing_ports() .iter() - .map(|&(node, port)| Wire::new(node, port)) + .map(|&(node, port)| Wire::new(node, port).into()) .collect(); Self { circ: extracted, @@ -97,40 +99,110 @@ impl Chunk { } /// Insert the chunk back into a circuit. 
- // - // TODO: The new chunk may have input ports directly connected to outputs. We have to take care of those. - #[allow(clippy::type_complexity)] pub(self) fn insert(&self, circ: &mut impl HugrMut, root: Node) -> ChunkInsertResult { + if self.circ.children(self.circ.root()).nth(2).is_none() { + // The chunk is empty. We don't need to insert anything. + return self.empty_chunk_insert_result(); + } + + let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap(); let chunk_sg: SiblingGraph<'_, DataflowParentID> = SiblingGraph::try_new(&self.circ, self.circ.root()).unwrap(); + // Insert the chunk circuit into the original circuit. let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg) - .expect("The chunk circuit is no longer a dataflow"); + .unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}")); let node_map = circ .insert_subgraph(root, &self.circ, &subgraph) .expect("Failed to insert the chunk subgraph") .node_map; - let [inp, out] = circ.get_io(root).unwrap(); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); - for (&connection, incoming) in self.inputs.iter().zip(subgraph.incoming_ports().iter()) { - let incoming = incoming.iter().map(|&(node, port)| { - if node == out { - // TODO: Add a map for directly connected Input connection -> Output Wire. - panic!("Chunk input directly connected to the output. This is not currently supported."); - } - (*node_map.get(&node).unwrap(),port) - }).collect_vec(); - input_map.insert(connection, incoming); + // Translate each connection from the chunk input into a [`ConnectionTarget`]. + // + // Connections to an inserted node are translated into a [`ConnectionTarget::InsertedNode`]. + // Connections from the input directly into the output become a [`ConnectionTarget::TransitiveConnection`]. 
+ for (&connection, chunk_inp_port) in + self.inputs.iter().zip(self.circ.node_outputs(chunk_inp)) + { + let connection_targets: Vec<ConnectionTarget> = self + .circ + .linked_ports(chunk_inp, chunk_inp_port) + .map(|(node, port)| { + if node == chunk_out { + // This was a direct wire from the chunk input to the output. Use the output's [`ChunkConnection`]. + let output_connection = self.outputs[port.index()]; + ConnectionTarget::TransitiveConnection(output_connection) + } else { + // Translate the original chunk node into the inserted node. + (*node_map.get(&node).unwrap(), port).into() + } + }) + .collect(); + input_map.insert(connection, connection_targets); } - for (&wire, &(node, port)) in self.outputs.iter().zip(subgraph.outgoing_ports().iter()) { - if node == inp { - // TODO: Add a map for directly connected Input Wire -> Output Wire. - panic!("Chunk input directly connected to the output. This is not currently supported."); - } - output_map.insert(wire, (*node_map.get(&node).unwrap(), port)); + for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { + let (node, port) = self + .circ + .linked_ports(chunk_out, chunk_out_port) + .exactly_one() + .ok() + .unwrap(); + let target = if node == chunk_inp { + // This was a direct wire from the chunk output to the input. Use the input's [`ChunkConnection`]. + let input_connection = self.inputs[port.index()]; + ConnectionTarget::TransitiveConnection(input_connection) + } else { + // Translate the original chunk node into the inserted node. + (*node_map.get(&node).unwrap(), port).into() + }; + output_map.insert(wire, target); + } + + ChunkInsertResult { + incoming_connections: input_map, + outgoing_connections: output_map, + } + } + + /// Compute the return value for `insert` when the chunk is empty (Subgraph would throw an error in this case). + /// + /// TODO: Support empty Subgraphs in Hugr? 
+ fn empty_chunk_insert_result(&self) -> ChunkInsertResult { + let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap(); + let mut input_map = HashMap::with_capacity(self.inputs.len()); + let mut output_map = HashMap::with_capacity(self.outputs.len()); + + for (&connection, chunk_inp_port) in + self.inputs.iter().zip(self.circ.node_outputs(chunk_inp)) + { + let connection_targets: Vec = self + .circ + .linked_ports(chunk_inp, chunk_inp_port) + .map(|(node, port)| { + assert_eq!(node, chunk_out); + let output_connection = self.outputs[port.index()]; + ConnectionTarget::TransitiveConnection(output_connection) + }) + .collect(); + input_map.insert(connection, connection_targets); + } + + for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { + let (node, port) = self + .circ + .linked_ports(chunk_out, chunk_out_port) + .exactly_one() + .ok() + .unwrap(); + assert_eq!(node, chunk_inp); + let input_connection = self.inputs[port.index()]; + output_map.insert( + wire, + ConnectionTarget::TransitiveConnection(input_connection), + ); } ChunkInsertResult { @@ -141,13 +213,25 @@ impl Chunk { } /// A map from the original input/output [`ChunkConnection`]s to an inserted chunk's inputs and outputs. +#[derive(Debug, Clone)] struct ChunkInsertResult { /// A map from incoming connections to a chunk, to the new node and incoming port targets. /// /// A chunk may specify multiple targets to be connected to a single incoming `ChunkConnection`. - pub incoming_connections: HashMap>, + pub incoming_connections: HashMap>, /// A map from outgoing connections from a chunk, to the new node and outgoing port target. - pub outgoing_connections: HashMap, + pub outgoing_connections: HashMap, +} + +/// The target of a chunk connection in a reassembled circuit. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)] +enum ConnectionTarget { + /// The target is a single node and port. 
+ #[from] + InsertedNode(Node, Port), + /// The link goes directly to the opposite boundary, without an intermediary + /// node. + TransitiveConnection(ChunkConnection), } /// An utility for splitting a circuit into chunks, and reassembling them @@ -172,7 +256,7 @@ pub struct CircuitChunks { output_connections: Vec, /// The split circuits. - pub chunks: Vec, + chunks: Vec, } impl CircuitChunks { @@ -180,7 +264,7 @@ impl CircuitChunks { /// /// The circuit is split into chunks of at most `max_size` gates. pub fn split(circ: &impl Circuit, max_size: usize) -> Self { - Self::split_with_cost(circ, max_size, |_| 1) + Self::split_with_cost(circ, max_size.saturating_sub(1), |_| 1) } /// Split a circuit into chunks. @@ -197,12 +281,12 @@ impl CircuitChunks { let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap(); let input_connections = circ .node_outputs(circ_input) - .map(|port| Wire::new(circ_input, port)) + .map(|port| Wire::new(circ_input, port).into()) .collect(); let output_connections = circ .node_inputs(circ_output) .flat_map(|p| circ.linked_ports(circ_output, p)) - .map(|(n, p)| Wire::new(n, p)) + .map(|(n, p)| Wire::new(n, p).into()) .collect(); let mut chunks = Vec::new(); @@ -256,6 +340,24 @@ impl CircuitChunks { let mut sources: HashMap = HashMap::new(); let mut targets: HashMap> = HashMap::new(); + // A map for `ChunkConnection`s that have been merged into another (due + // to identity wires in the updated chunks). + // + // Maps each `ChunkConnection` to the `ChunkConnection` it has been + // merged into. + // + // This is a poor man's Union Find. Since we traverse the chunks in + // order, we can assume that already seen connections will not be merged + // again. 
+ let mut transitive_connections: HashMap<ChunkConnection, ChunkConnection> = HashMap::new(); + let get_merged_connection = |transitive_connections: &HashMap<_, _>, connection| { + transitive_connections + .get(&connection) + .copied() + .unwrap_or(connection) + }; + + // Register the source ports for the `ChunkConnections` in the circuit input. for (&connection, port) in self .input_connections .iter() @@ -263,13 +365,6 @@ impl CircuitChunks { { sources.insert(connection, (reassembled_input, port)); } - for (&connection, port) in self - .output_connections - .iter() - .zip(reassembled.node_inputs(reassembled_output)) - { - targets.insert(connection, vec![(reassembled_output, port)]); - } for chunk in self.chunks { // Insert the chunk circuit without its input/output nodes. @@ -277,11 +372,57 @@ impl CircuitChunks { incoming_connections, outgoing_connections, } = chunk.insert(&mut reassembled, root); - // Reconnect the chunk's inputs and outputs in the reassembled circuit. - sources.extend(outgoing_connections); - incoming_connections.into_iter().for_each(|(wire, tgts)| { - targets.entry(wire).or_default().extend(tgts); - }); + // Associate the chunk's inserted inputs and outputs to the + // `ChunkConnection` identifiers, so we can re-connect everything + // afterwards. + // + // The chunk may return `ConnectionTarget::TransitiveConnection`s to + // indicate that a `ChunkConnection` has been merged into another + // (due to an identity wire). + for (connection, conn_target) in outgoing_connections { + match conn_target { + ConnectionTarget::InsertedNode(node, port) => { + // The output of a chunk always has fresh `ChunkConnection`s. + sources.insert(connection, (node, port)); + } + ConnectionTarget::TransitiveConnection(merged_connection) => { + // The output's `ChunkConnection` has been merged into one of the input's. 
+ let merged_connection = + get_merged_connection(&transitive_connections, merged_connection); + transitive_connections.insert(connection, merged_connection); + } + } + } + for (connection, conn_targets) in incoming_connections { + // The connection in the chunk's input may have been merged into an earlier one. + let connection = get_merged_connection(&transitive_connections, connection); + for tgt in conn_targets { + match tgt { + ConnectionTarget::InsertedNode(node, port) => { + targets.entry(connection).or_default().push((node, port)); + } + ConnectionTarget::TransitiveConnection(_merged_connection) => { + // The merge has been registered when scanning the + // outgoing_connections, so we don't need to do + // anything here. + } + } + } + } + } + + // Register the target ports for the `ChunkConnections` into the circuit output. + for (&connection, port) in self + .output_connections + .iter() + .zip(reassembled.node_inputs(reassembled_output)) + { + // The connection at the circuit output may have been merged into an earlier one. + let connection = get_merged_connection(&transitive_connections, connection); + targets + .entry(connection) + .or_default() + .push((reassembled_output, port)); } // Reconnect the different chunks. @@ -389,14 +530,51 @@ mod test { }) .unwrap(); - let mut chunks = CircuitChunks::split(&circ, 3); + let chunks = CircuitChunks::split(&circ, 3); - // Rearrange the chunks so nodes are inserted in a new order. 
- chunks.chunks.reverse(); + assert_eq!(chunks.len(), 3); let mut reassembled = chunks.reassemble().unwrap(); reassembled.infer_and_validate(&REGISTRY).unwrap(); assert_eq!(circ.circuit_hash(), reassembled.circuit_hash()); } + + #[test] + fn reassemble_empty() { + let circ = build_simple_circuit(3, |circ| { + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::H, [0])?; + circ.append(T2Op::H, [1])?; + Ok(()) + }) + .unwrap(); + + let circ_1q_id = build_simple_circuit(1, |_| Ok(())).unwrap(); + let circ_2q_id_h = build_simple_circuit(2, |circ| { + circ.append(T2Op::H, [0])?; + Ok(()) + }) + .unwrap(); + + let mut chunks = CircuitChunks::split(&circ, 1); + + // Replace the Hs with identities, and the CX with an identity and an H gate. + chunks[0] = circ_2q_id_h.clone(); + chunks[1] = circ_1q_id.clone(); + chunks[2] = circ_1q_id.clone(); + + let mut reassembled = chunks.reassemble().unwrap(); + + reassembled.infer_and_validate(&REGISTRY).unwrap(); + + assert_eq!(reassembled.commands().count(), 1); + let h = reassembled.commands().next().unwrap().node(); + + let [inp, out] = reassembled.get_io(reassembled.root()).unwrap(); + assert_eq!( + &reassembled.output_neighbours(inp).collect_vec(), + &[h, out, out] + ); + } }