Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support empty wires in CircuitChunks #172

Merged
merged 7 commits into from
Oct 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 223 additions & 45 deletions src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use std::collections::HashMap;
use std::mem;
use std::ops::{Index, IndexMut};

use derive_more::From;
use hugr::builder::{Container, FunctionBuilder};
use hugr::extension::ExtensionSet;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::ConvexChecker;
use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph};
use hugr::hugr::{HugrError, NodeMetadata};
use hugr::hugr::{HugrError, NodeMetadata, PortIndex};
use hugr::ops::handle::DataflowParentID;
use hugr::ops::OpType;
use hugr::types::{FunctionType, Signature};
Expand All @@ -34,7 +35,8 @@ use tket_json_rs::circuit_json::SerialCircuit;
///
/// When reassembling the circuit, the input/output wires of each chunk are
/// re-linked by matching these identifiers.
pub type ChunkConnection = Wire;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)]
pub struct ChunkConnection(Wire);

/// A chunk of a circuit.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -81,13 +83,13 @@ impl Chunk {
.exactly_one()
.ok()
.unwrap();
Wire::new(out_node, out_port)
Wire::new(out_node, out_port).into()
})
.collect();
let outputs = subgraph
.outgoing_ports()
.iter()
.map(|&(node, port)| Wire::new(node, port))
.map(|&(node, port)| Wire::new(node, port).into())
.collect();
Self {
circ: extracted,
Expand All @@ -97,40 +99,110 @@ impl Chunk {
}

/// Insert the chunk back into a circuit.
//
// TODO: The new chunk may have input ports directly connected to outputs. We have to take care of those.
#[allow(clippy::type_complexity)]
pub(self) fn insert(&self, circ: &mut impl HugrMut, root: Node) -> ChunkInsertResult {
if self.circ.children(self.circ.root()).nth(2).is_none() {
// The chunk is empty. We don't need to insert anything.
return self.empty_chunk_insert_result();
}

let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap();
let chunk_sg: SiblingGraph<'_, DataflowParentID> =
SiblingGraph::try_new(&self.circ, self.circ.root()).unwrap();
// Insert the chunk circuit into the original circuit.
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg)
.expect("The chunk circuit is no longer a dataflow");
.unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}"));
let node_map = circ
.insert_subgraph(root, &self.circ, &subgraph)
.expect("Failed to insert the chunk subgraph")
.node_map;

let [inp, out] = circ.get_io(root).unwrap();
let mut input_map = HashMap::with_capacity(self.inputs.len());
let mut output_map = HashMap::with_capacity(self.outputs.len());

for (&connection, incoming) in self.inputs.iter().zip(subgraph.incoming_ports().iter()) {
let incoming = incoming.iter().map(|&(node, port)| {
if node == out {
// TODO: Add a map for directly connected Input connection -> Output Wire.
panic!("Chunk input directly connected to the output. This is not currently supported.");
}
(*node_map.get(&node).unwrap(),port)
}).collect_vec();
input_map.insert(connection, incoming);
// Translate each connection from the chunk input into a [`ConnectionTarget`].
//
// Connections to an inserted node are translated into a [`ConnectionTarget::InsertedNode`].
// Connections from the input directly into the output become a [`ConnectionTarget::TransitiveConnection`].
for (&connection, chunk_inp_port) in
self.inputs.iter().zip(self.circ.node_outputs(chunk_inp))
{
let connection_targets: Vec<ConnectionTarget> = self
.circ
.linked_ports(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
if node == chunk_out {
// This was a direct wire from the chunk input to the output. Use the output's [`ChunkConnection`].
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
} else {
// Translate the original chunk node into the inserted node.
(*node_map.get(&node).unwrap(), port).into()
}
})
.collect();
input_map.insert(connection, connection_targets);
}

for (&wire, &(node, port)) in self.outputs.iter().zip(subgraph.outgoing_ports().iter()) {
if node == inp {
// TODO: Add a map for directly connected Input Wire -> Output Wire.
panic!("Chunk input directly connected to the output. This is not currently supported.");
}
output_map.insert(wire, (*node_map.get(&node).unwrap(), port));
for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) {
let (node, port) = self
.circ
.linked_ports(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
let target = if node == chunk_inp {
// This was a direct wire from the chunk output to the input. Use the input's [`ChunkConnection`].
let input_connection = self.inputs[port.index()];
ConnectionTarget::TransitiveConnection(input_connection)
} else {
// Translate the original chunk node into the inserted node.
(*node_map.get(&node).unwrap(), port).into()
};
output_map.insert(wire, target);
}

ChunkInsertResult {
incoming_connections: input_map,
outgoing_connections: output_map,
}
}

/// Compute the return value for `insert` when the chunk is empty (Subgraph would throw an error in this case).
///
/// TODO: Support empty Subgraphs in Hugr?
fn empty_chunk_insert_result(&self) -> ChunkInsertResult {
let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap();
let mut input_map = HashMap::with_capacity(self.inputs.len());
let mut output_map = HashMap::with_capacity(self.outputs.len());

for (&connection, chunk_inp_port) in
self.inputs.iter().zip(self.circ.node_outputs(chunk_inp))
{
let connection_targets: Vec<ConnectionTarget> = self
.circ
.linked_ports(chunk_inp, chunk_inp_port)
.map(|(node, port)| {
assert_eq!(node, chunk_out);
let output_connection = self.outputs[port.index()];
ConnectionTarget::TransitiveConnection(output_connection)
})
.collect();
input_map.insert(connection, connection_targets);
}

for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) {
let (node, port) = self
.circ
.linked_ports(chunk_out, chunk_out_port)
.exactly_one()
.ok()
.unwrap();
assert_eq!(node, chunk_inp);
let input_connection = self.inputs[port.index()];
output_map.insert(
wire,
ConnectionTarget::TransitiveConnection(input_connection),
);
}

ChunkInsertResult {
Expand All @@ -141,13 +213,25 @@ impl Chunk {
}

/// A map from the original input/output [`ChunkConnection`]s to an inserted chunk's inputs and outputs.
#[derive(Debug, Clone)]
struct ChunkInsertResult {
/// A map from incoming connections to a chunk, to the new node and incoming port targets.
///
/// A chunk may specify multiple targets to be connected to a single incoming `ChunkConnection`.
pub incoming_connections: HashMap<ChunkConnection, Vec<(Node, Port)>>,
pub incoming_connections: HashMap<ChunkConnection, Vec<ConnectionTarget>>,
/// A map from outgoing connections from a chunk, to the new node and outgoing port target.
pub outgoing_connections: HashMap<ChunkConnection, (Node, Port)>,
pub outgoing_connections: HashMap<ChunkConnection, ConnectionTarget>,
}

/// The target of a chunk connection in a reassembled circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, From)]
enum ConnectionTarget {
/// The target is a single node and port.
#[from]
InsertedNode(Node, Port),
/// The link goes directly to the opposite boundary, without an intermediary
/// node.
TransitiveConnection(ChunkConnection),
}

/// An utility for splitting a circuit into chunks, and reassembling them
Expand All @@ -172,15 +256,15 @@ pub struct CircuitChunks {
output_connections: Vec<ChunkConnection>,

/// The split circuits.
pub chunks: Vec<Chunk>,
chunks: Vec<Chunk>,
}

impl CircuitChunks {
/// Split a circuit into chunks.
///
/// The circuit is split into chunks of at most `max_size` gates.
pub fn split(circ: &impl Circuit, max_size: usize) -> Self {
Self::split_with_cost(circ, max_size, |_| 1)
Self::split_with_cost(circ, max_size.saturating_sub(1), |_| 1)
}

/// Split a circuit into chunks.
Expand All @@ -197,12 +281,12 @@ impl CircuitChunks {
let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap();
let input_connections = circ
.node_outputs(circ_input)
.map(|port| Wire::new(circ_input, port))
.map(|port| Wire::new(circ_input, port).into())
.collect();
let output_connections = circ
.node_inputs(circ_output)
.flat_map(|p| circ.linked_ports(circ_output, p))
.map(|(n, p)| Wire::new(n, p))
.map(|(n, p)| Wire::new(n, p).into())
.collect();

let mut chunks = Vec::new();
Expand Down Expand Up @@ -256,32 +340,89 @@ impl CircuitChunks {
let mut sources: HashMap<ChunkConnection, (Node, Port)> = HashMap::new();
let mut targets: HashMap<ChunkConnection, Vec<(Node, Port)>> = HashMap::new();

// A map for `ChunkConnection`s that have been merged into another (due
// to identity wires in the updated chunks).
//
// Maps each `ChunkConnection` to the `ChunkConnection` it has been
// merged into.
//
// This is a poor man's Union Find. Since we traverse the chunks in
// order, we can assume that already seen connections will not be merged
// again.
let mut transitive_connections: HashMap<ChunkConnection, ChunkConnection> = HashMap::new();
let get_merged_connection = |transitive_connections: &HashMap<_, _>, connection| {
transitive_connections
.get(&connection)
.copied()
.unwrap_or(connection)
};

// Register the source ports for the `ChunkConnections` in the circuit input.
for (&connection, port) in self
.input_connections
.iter()
.zip(reassembled.node_outputs(reassembled_input))
{
sources.insert(connection, (reassembled_input, port));
}
for (&connection, port) in self
.output_connections
.iter()
.zip(reassembled.node_inputs(reassembled_output))
{
targets.insert(connection, vec![(reassembled_output, port)]);
}

for chunk in self.chunks {
// Insert the chunk circuit without its input/output nodes.
let ChunkInsertResult {
incoming_connections,
outgoing_connections,
} = chunk.insert(&mut reassembled, root);
// Reconnect the chunk's inputs and outputs in the reassembled circuit.
sources.extend(outgoing_connections);
incoming_connections.into_iter().for_each(|(wire, tgts)| {
targets.entry(wire).or_default().extend(tgts);
});
// Associate the chunk's inserted inputs and outputs to the
// `ChunkConnection` identifiers, so we can re-connect everything
// afterwards.
//
// The chunk may return `ConnectionTarget::TransitiveConnection`s to
// indicate that a `ChunkConnection` has been merged into another
// (due to an identity wire).
for (connection, conn_target) in outgoing_connections {
match conn_target {
ConnectionTarget::InsertedNode(node, port) => {
// The output of a chunk always has fresh `ChunkConnection`s.
sources.insert(connection, (node, port));
}
ConnectionTarget::TransitiveConnection(merged_connection) => {
// The output's `ChunkConnection` has been merged into one of the input's.
let merged_connection =
get_merged_connection(&transitive_connections, merged_connection);
transitive_connections.insert(connection, merged_connection);
}
}
}
for (connection, conn_targets) in incoming_connections {
// The connection in the chunk's input may have been merged into a earlier one.
let connection = get_merged_connection(&transitive_connections, connection);
for tgt in conn_targets {
match tgt {
ConnectionTarget::InsertedNode(node, port) => {
targets.entry(connection).or_default().push((node, port));
}
ConnectionTarget::TransitiveConnection(_merged_connection) => {
// The merge has been registered when scanning the
// outgoing_connections, so we don't need to do
// anything here.
}
}
}
}
}

// Register the target ports for the `ChunkConnections` into the circuit output.
for (&connection, port) in self
.output_connections
.iter()
.zip(reassembled.node_inputs(reassembled_output))
{
// The connection in the chunk's input may have been merged into a earlier one.
let connection = get_merged_connection(&transitive_connections, connection);
targets
.entry(connection)
.or_default()
.push((reassembled_output, port));
}

// Reconnect the different chunks.
Expand Down Expand Up @@ -389,14 +530,51 @@ mod test {
})
.unwrap();

let mut chunks = CircuitChunks::split(&circ, 3);
let chunks = CircuitChunks::split(&circ, 3);

// Rearrange the chunks so nodes are inserted in a new order.
chunks.chunks.reverse();
assert_eq!(chunks.len(), 3);

let mut reassembled = chunks.reassemble().unwrap();

reassembled.infer_and_validate(&REGISTRY).unwrap();
assert_eq!(circ.circuit_hash(), reassembled.circuit_hash());
}

#[test]
fn reassemble_empty() {
let circ = build_simple_circuit(3, |circ| {
circ.append(T2Op::CX, [0, 1])?;
circ.append(T2Op::H, [0])?;
circ.append(T2Op::H, [1])?;
Ok(())
})
.unwrap();

let circ_1q_id = build_simple_circuit(1, |_| Ok(())).unwrap();
let circ_2q_id_h = build_simple_circuit(2, |circ| {
circ.append(T2Op::H, [0])?;
Ok(())
})
.unwrap();

let mut chunks = CircuitChunks::split(&circ, 1);

// Replace the Hs with identities, and the CX with an identity and an H gate.
chunks[0] = circ_2q_id_h.clone();
chunks[1] = circ_1q_id.clone();
chunks[2] = circ_1q_id.clone();

let mut reassembled = chunks.reassemble().unwrap();

reassembled.infer_and_validate(&REGISTRY).unwrap();

assert_eq!(reassembled.commands().count(), 1);
let h = reassembled.commands().next().unwrap().node();

let [inp, out] = reassembled.get_io(reassembled.root()).unwrap();
assert_eq!(
&reassembled.output_neighbours(inp).collect_vec(),
&[h, out, out]
);
}
}