diff --git a/tket2/src/passes/tuple_unpack.rs b/tket2/src/passes/tuple_unpack.rs index 43e566f7..84d5875d 100644 --- a/tket2/src/passes/tuple_unpack.rs +++ b/tket2/src/passes/tuple_unpack.rs @@ -3,7 +3,7 @@ use core::panic; use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; -use hugr::ops::{OpTrait, OpType}; +use hugr::ops::{MakeTuple, OpTrait, OpType}; use hugr::types::Type; use hugr::{HugrView, Node}; use itertools::Itertools; @@ -72,41 +72,49 @@ fn make_rewrite(circ: &Circuit, cmd: Command) -> Option Some(remove_pack_unpack( - circ, - &tuple_types, - tuple_node, - unpack_nodes, - )), - false => { - // TODO: Add a rewrite to remove some of the unpack operations. - None - } - } + let num_other_outputs = links.len() - unpack_nodes.len(); + Some(remove_pack_unpack( + circ, + &tuple_types, + tuple_node, + unpack_nodes, + num_other_outputs, + )) } -/// Returns a rewrite to remove a tuple pack operation that's only followed by unpack operations. +/// Returns a rewrite to remove a tuple pack operation that's followed by unpack operations, +/// and `other_tuple_links` other operations. fn remove_pack_unpack( circ: &Circuit, tuple_types: &[Type], pack_node: Node, unpack_nodes: Vec, + num_other_outputs: usize, ) -> CircuitRewrite { - let num_outputs = tuple_types.len() * unpack_nodes.len(); + let num_unpack_outputs = tuple_types.len() * unpack_nodes.len(); let mut nodes = unpack_nodes; nodes.push(pack_node); let subcirc = Subcircuit::try_from_nodes(nodes, circ).unwrap(); - let replacement = DFGBuilder::new(subcirc.signature(circ)).unwrap(); - let wires = replacement - .input_wires() - .cycle() - .take(num_outputs) - .collect_vec(); + let mut replacement = DFGBuilder::new(subcirc.signature(circ)).unwrap(); + let mut outputs = Vec::with_capacity(num_unpack_outputs + num_other_outputs); + + // If needed, re-add the tuple pack node and connect its output to the tuple outputs. + if num_other_outputs > 0 { + let op = MakeTuple::new(tuple_types.to_vec().into()); + let [tuple] = replacement + .add_dataflow_op(op, replacement.input_wires()) + .unwrap() + .outputs_arr(); + outputs.extend(std::iter::repeat(tuple).take(num_other_outputs)) + } + + // Wire the inputs directly to the unpack outputs + outputs.extend(replacement.input_wires().cycle().take(num_unpack_outputs)); + let replacement = replacement - .finish_prelude_hugr_with_outputs(wires) + .finish_prelude_hugr_with_outputs(outputs) .unwrap_or_else(|e| { panic!("Failed to create replacement for removing tuple pack/unpack operations. {e}") }) @@ -205,8 +213,6 @@ mod test { #[rstest] #[case::simple(simple_pack_unpack(), 1, 0)] #[case::multi(multi_unpack(), 1, 0)] - // TODO: Partial unpack is not currently supported. - #[ignore = "Unimplemented."] #[case::partial(partial_unpack(), 1, 1)] fn test_pack_unpack( #[case] mut circ: Circuit,