Skip to content

Commit

Permalink
fix: Support empty wires in CircuitChunks (#172)
Browse files Browse the repository at this point in the history
Now TASO optimizations that result in identity wires won't panic when
reassembling the split chunks.
  • Loading branch information
aborgna-q authored Oct 5, 2023
1 parent 4174e14 commit f42746f
Showing 1 changed file with 223 additions and 45 deletions.
268 changes: 223 additions & 45 deletions src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use std::collections::HashMap;
use std::mem;
use std::ops::{Index, IndexMut};

use derive_more::From;
use hugr::builder::{Container, FunctionBuilder};
use hugr::extension::ExtensionSet;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::ConvexChecker;
use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph};
use hugr::hugr::{HugrError, NodeMetadata};
use hugr::hugr::{HugrError, NodeMetadata, PortIndex};
use hugr::ops::handle::DataflowParentID;
use hugr::ops::OpType;
use hugr::types::{FunctionType, Signature};
Expand All @@ -34,7 +35,8 @@ use tket_json_rs::circuit_json::SerialCircuit;
///
/// When reassembling the circuit, the input/output wires of each chunk are
/// re-linked by matching these identifiers.
pub type ChunkConnection = Wire;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)]
pub struct ChunkConnection(Wire);

/// A chunk of a circuit.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -81,13 +83,13 @@ impl Chunk {
.exactly_one()
.ok()
.unwrap();
Wire::new(out_node, out_port)
Wire::new(out_node, out_port).into()
})
.collect();
let outputs = subgraph
.outgoing_ports()
.iter()
.map(|&(node, port)| Wire::new(node, port))
.map(|&(node, port)| Wire::new(node, port).into())
.collect();
Self {
circ: extracted,
Expand All @@ -97,40 +99,110 @@ impl Chunk {
}

/// Insert the chunk back into a circuit.
//
// TODO: The new chunk may have input ports directly connected to outputs. We have to take care of those.
#[allow(clippy::type_complexity)]
pub(self) fn insert(&self, circ: &mut impl HugrMut, root: Node) -> ChunkInsertResult {
if self.circ.children(self.circ.root()).nth(2).is_none() {
// The chunk is empty. We don't need to insert anything.
return self.empty_chunk_insert_result();
}

let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap();
let chunk_sg: SiblingGraph<'_, DataflowParentID> =
SiblingGraph::try_new(&self.circ, self.circ.root()).unwrap();
// Insert the chunk circuit into the original circuit.
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg)
.expect("The chunk circuit is no longer a dataflow");
.unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}"));
let node_map = circ
.insert_subgraph(root, &self.circ, &subgraph)
.expect("Failed to insert the chunk subgraph")
.node_map;

let [inp, out] = circ.get_io(root).unwrap();
let mut input_map = HashMap::with_capacity(self.inputs.len());
let mut output_map = HashMap::with_capacity(self.outputs.len());

for (&connection, incoming) in self.inputs.iter().zip(subgraph.incoming_ports().iter()) {
let incoming = incoming.iter().map(|&(node, port)| {
if node == out {
// TODO: Add a map for directly connected Input connection -> Output Wire.
panic!("Chunk input directly connected to the output. This is not currently supported.");
}
(*node_map.get(&node).unwrap(),port)
}).collect_vec();
input_map.insert(connection, incoming);
// Translate each connection from the chunk input into a [`ConnectionTarget`].
//
// Connections to an inserted node are translated into a [`ConnectionTarget::InsertedNode`].
// Connections from the input directly into the output become a [`ConnectionTarget::TransitiveConnection`].
for (&connection, chunk_inp_port) in
self.inputs.iter().zip(self.circ.node_outputs(chunk_inp))
{
let connection_targets: Vec<ConnectionTarget> = self
.circ
.linked_ports(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
if node == chunk_out {
// This was a direct wire from the chunk input to the output. Use the output's [`ChunkConnection`].
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
} else {
// Translate the original chunk node into the inserted node.
(*node_map.get(&node).unwrap(), port).into()
}
})
.collect();
input_map.insert(connection, connection_targets);
}

for (&wire, &(node, port)) in self.outputs.iter().zip(subgraph.outgoing_ports().iter()) {
if node == inp {
// TODO: Add a map for directly connected Input Wire -> Output Wire.
panic!("Chunk input directly connected to the output. This is not currently supported.");
}
output_map.insert(wire, (*node_map.get(&node).unwrap(), port));
for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) {
let (node, port) = self
.circ
.linked_ports(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
let target = if node == chunk_inp {
// This was a direct wire from the chunk output to the input. Use the input's [`ChunkConnection`].
let input_connection = self.inputs[port.index()];
ConnectionTarget::TransitiveConnection(input_connection)
} else {
// Translate the original chunk node into the inserted node.
(*node_map.get(&node).unwrap(), port).into()
};
output_map.insert(wire, target);
}

ChunkInsertResult {
incoming_connections: input_map,
outgoing_connections: output_map,
}
}

/// Compute the return value for `insert` when the chunk is empty (Subgraph would throw an error in this case).
///
/// TODO: Support empty Subgraphs in Hugr?
fn empty_chunk_insert_result(&self) -> ChunkInsertResult {
let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap();
let mut input_map = HashMap::with_capacity(self.inputs.len());
let mut output_map = HashMap::with_capacity(self.outputs.len());

for (&connection, chunk_inp_port) in
self.inputs.iter().zip(self.circ.node_outputs(chunk_inp))
{
let connection_targets: Vec<ConnectionTarget> = self
.circ
.linked_ports(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
assert_eq!(node, chunk_out);
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
})
.collect();
input_map.insert(connection, connection_targets);
}

for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) {
let (node, port) = self
.circ
.linked_ports(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
assert_eq!(node, chunk_inp);
let input_connection = self.inputs[port.index()];
output_map.insert(
wire,
ConnectionTarget::TransitiveConnection(input_connection),
);
}

ChunkInsertResult {
Expand All @@ -141,13 +213,25 @@ impl Chunk {
}

/// A map from the original input/output [`ChunkConnection`]s to an inserted chunk's inputs and outputs.
#[derive(Debug, Clone)]
struct ChunkInsertResult {
/// A map from incoming connections to a chunk, to the new node and incoming port targets.
///
/// A chunk may specify multiple targets to be connected to a single incoming `ChunkConnection`.
pub incoming_connections: HashMap<ChunkConnection, Vec<(Node, Port)>>,
pub incoming_connections: HashMap<ChunkConnection, Vec<ConnectionTarget>>,
/// A map from outgoing connections from a chunk, to the new node and outgoing port target.
pub outgoing_connections: HashMap<ChunkConnection, (Node, Port)>,
pub outgoing_connections: HashMap<ChunkConnection, ConnectionTarget>,
}

/// The target of a chunk connection in a reassembled circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)]
enum ConnectionTarget {
/// The target is a single node and port.
#[from]
InsertedNode(Node, Port),
/// The link goes directly to the opposite boundary, without an intermediary
/// node.
TransitiveConnection(ChunkConnection),
}

/// An utility for splitting a circuit into chunks, and reassembling them
Expand All @@ -172,15 +256,15 @@ pub struct CircuitChunks {
output_connections: Vec<ChunkConnection>,

/// The split circuits.
pub chunks: Vec<Chunk>,
chunks: Vec<Chunk>,
}

impl CircuitChunks {
/// Split a circuit into chunks.
///
/// The circuit is split into chunks of at most `max_size` gates.
pub fn split(circ: &impl Circuit, max_size: usize) -> Self {
Self::split_with_cost(circ, max_size, |_| 1)
Self::split_with_cost(circ, max_size.saturating_sub(1), |_| 1)
}

/// Split a circuit into chunks.
Expand All @@ -197,12 +281,12 @@ impl CircuitChunks {
let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap();
let input_connections = circ
.node_outputs(circ_input)
.map(|port| Wire::new(circ_input, port))
.map(|port| Wire::new(circ_input, port).into())
.collect();
let output_connections = circ
.node_inputs(circ_output)
.flat_map(|p| circ.linked_ports(circ_output, p))
.map(|(n, p)| Wire::new(n, p))
.map(|(n, p)| Wire::new(n, p).into())
.collect();

let mut chunks = Vec::new();
Expand Down Expand Up @@ -256,32 +340,89 @@ impl CircuitChunks {
let mut sources: HashMap<ChunkConnection, (Node, Port)> = HashMap::new();
let mut targets: HashMap<ChunkConnection, Vec<(Node, Port)>> = HashMap::new();

// A map for `ChunkConnection`s that have been merged into another (due
// to identity wires in the updated chunks).
//
// Maps each `ChunkConnection` to the `ChunkConnection` it has been
// merged into.
//
// This is a poor man's Union Find. Since we traverse the chunks in
// order, we can assume that already seen connections will not be merged
// again.
let mut transitive_connections: HashMap<ChunkConnection, ChunkConnection> = HashMap::new();
let get_merged_connection = |transitive_connections: &HashMap<_, _>, connection| {
transitive_connections
.get(&connection)
.copied()
.unwrap_or(connection)
};

// Register the source ports for the `ChunkConnections` in the circuit input.
for (&connection, port) in self
.input_connections
.iter()
.zip(reassembled.node_outputs(reassembled_input))
{
sources.insert(connection, (reassembled_input, port));
}
for (&connection, port) in self
.output_connections
.iter()
.zip(reassembled.node_inputs(reassembled_output))
{
targets.insert(connection, vec![(reassembled_output, port)]);
}

for chunk in self.chunks {
// Insert the chunk circuit without its input/output nodes.
let ChunkInsertResult {
incoming_connections,
outgoing_connections,
} = chunk.insert(&mut reassembled, root);
// Reconnect the chunk's inputs and outputs in the reassembled circuit.
sources.extend(outgoing_connections);
incoming_connections.into_iter().for_each(|(wire, tgts)| {
targets.entry(wire).or_default().extend(tgts);
});
// Associate the chunk's inserted inputs and outputs to the
// `ChunkConnection` identifiers, so we can re-connect everything
// afterwards.
//
// The chunk may return `ConnectionTarget::TransitiveConnection`s to
// indicate that a `ChunkConnection` has been merged into another
// (due to an identity wire).
for (connection, conn_target) in outgoing_connections {
match conn_target {
ConnectionTarget::InsertedNode(node, port) => {
// The output of a chunk always has fresh `ChunkConnection`s.
sources.insert(connection, (node, port));
}
ConnectionTarget::TransitiveConnection(merged_connection) => {
// The output's `ChunkConnection` has been merged into one of the input's.
let merged_connection =
get_merged_connection(&transitive_connections, merged_connection);
transitive_connections.insert(connection, merged_connection);
}
}
}
for (connection, conn_targets) in incoming_connections {
// The connection in the chunk's input may have been merged into a earlier one.
let connection = get_merged_connection(&transitive_connections, connection);
for tgt in conn_targets {
match tgt {
ConnectionTarget::InsertedNode(node, port) => {
targets.entry(connection).or_default().push((node, port));
}
ConnectionTarget::TransitiveConnection(_merged_connection) => {
// The merge has been registered when scanning the
// outgoing_connections, so we don't need to do
// anything here.
}
}
}
}
}

// Register the target ports for the `ChunkConnections` into the circuit output.
for (&connection, port) in self
.output_connections
.iter()
.zip(reassembled.node_inputs(reassembled_output))
{
// The connection in the chunk's input may have been merged into a earlier one.
let connection = get_merged_connection(&transitive_connections, connection);
targets
.entry(connection)
.or_default()
.push((reassembled_output, port));
}

// Reconnect the different chunks.
Expand Down Expand Up @@ -389,14 +530,51 @@ mod test {
})
.unwrap();

let mut chunks = CircuitChunks::split(&circ, 3);
let chunks = CircuitChunks::split(&circ, 3);

// Rearrange the chunks so nodes are inserted in a new order.
chunks.chunks.reverse();
assert_eq!(chunks.len(), 3);

let mut reassembled = chunks.reassemble().unwrap();

reassembled.infer_and_validate(&REGISTRY).unwrap();
assert_eq!(circ.circuit_hash(), reassembled.circuit_hash());
}

#[test]
fn reassemble_empty() {
let circ = build_simple_circuit(3, |circ| {
circ.append(T2Op::CX, [0, 1])?;
circ.append(T2Op::H, [0])?;
circ.append(T2Op::H, [1])?;
Ok(())
})
.unwrap();

let circ_1q_id = build_simple_circuit(1, |_| Ok(())).unwrap();
let circ_2q_id_h = build_simple_circuit(2, |circ| {
circ.append(T2Op::H, [0])?;
Ok(())
})
.unwrap();

let mut chunks = CircuitChunks::split(&circ, 1);

// Replace the Hs with identities, and the CX with an identity and an H gate.
chunks[0] = circ_2q_id_h.clone();
chunks[1] = circ_1q_id.clone();
chunks[2] = circ_1q_id.clone();

let mut reassembled = chunks.reassemble().unwrap();

reassembled.infer_and_validate(&REGISTRY).unwrap();

assert_eq!(reassembled.commands().count(), 1);
let h = reassembled.commands().next().unwrap().node();

let [inp, out] = reassembled.get_io(reassembled.root()).unwrap();
assert_eq!(
&reassembled.output_neighbours(inp).collect_vec(),
&[h, out, out]
);
}
}

0 comments on commit f42746f

Please sign in to comment.