diff --git a/src/circuit.rs b/src/circuit.rs index 8f9e9a22..5bf0d1ff 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -9,10 +9,18 @@ pub use hash::CircuitHash; use hugr::HugrView; +use derive_more::From; +use hugr::hugr::hugrmut::HugrMut; +use hugr::hugr::{NodeType, PortIndex}; +use hugr::ops::dataflow::IOTrait; pub use hugr::ops::OpType; +use hugr::ops::{Input, Output, DFG}; use hugr::types::FunctionType; pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; +use itertools::Itertools; +use portgraph::Direction; +use thiserror::Error; use self::units::{filter, FilteredUnits, Units}; @@ -120,18 +128,164 @@ pub trait Circuit: HugrView { } } +/// Remove an empty wire in a dataflow HUGR. +/// +/// The wire to be removed is identified by the index of the outgoing port +/// at the circuit input node. +/// +/// This will change the circuit signature and will shift all ports after +/// the removed wire by -1. If the wire is connected to the output node, +/// this will also change the signature output and shift the ports after +/// the removed wire by -1. +/// +/// This will return an error if the wire is not empty or if a HugrError +/// occurs. +#[allow(dead_code)] +pub(crate) fn remove_empty_wire( + circ: &mut impl HugrMut, + input_port: usize, +) -> Result<(), CircuitMutError> { + let [inp, out] = circ.get_io(circ.root()).expect("no IO nodes found at root"); + if input_port >= circ.num_outputs(inp) { + return Err(CircuitMutError::InvalidPortOffset(input_port)); + } + let input_port = Port::new_outgoing(input_port); + let link = circ + .linked_ports(inp, input_port) + .at_most_one() + .map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?; + if link.is_some() && link.unwrap().0 != out { + return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index())); + } + if link.is_some() { + circ.disconnect(inp, input_port)?; + } + + // Shift ports at input + shift_ports(circ, inp, input_port, circ.num_outputs(inp))?; + // Shift ports at output + if let Some((out, output_port)) = link { + shift_ports(circ, out, output_port, circ.num_inputs(out))?; + } + // Update input node, output node (if necessary) and root signatures. + update_signature(circ, input_port.index(), link.map(|(_, p)| p.index())); + // Resize ports at input/output node + circ.set_num_ports(inp, 0, circ.num_outputs(inp) - 1); + if let Some((out, _)) = link { + circ.set_num_ports(out, circ.num_inputs(out) - 1, 0); + } + Ok(()) +} + +/// Errors that can occur when mutating a circuit. +#[derive(Debug, Clone, Error, PartialEq, Eq, From)] +pub enum CircuitMutError { + /// A Hugr error occurred. + #[error("Hugr error: {0:?}")] + HugrError(hugr::hugr::HugrError), + /// The wire to be deleted is not empty. + #[from(ignore)] + #[error("Wire {0} cannot be deleted: not empty")] + DeleteNonEmptyWire(usize), + /// The wire does not exist. + #[from(ignore)] + #[error("Wire {0} does not exist")] + InvalidPortOffset(usize), +} + +/// Shift ports in range (free_port + 1 .. max_ind) by -1. +fn shift_ports( + circ: &mut C, + node: Node, + mut free_port: Port, + max_ind: usize, +) -> Result { + let dir = free_port.direction(); + let port_range = (free_port.index() + 1..max_ind).map(|p| Port::new(dir, p)); + for port in port_range { + let links = circ.linked_ports(node, port).collect_vec(); + if !links.is_empty() { + circ.disconnect(node, port)?; + } + for (other_n, other_p) in links { + // TODO: simplify when CQCL-DEV/hugr#565 is resolved + match dir { + Direction::Incoming => circ.connect(other_n, other_p, node, free_port), + Direction::Outgoing => circ.connect(node, free_port, other_n, other_p), + }?; + } + free_port = port; + } + Ok(free_port) +} + +// Update the signature of circ when removing the in_index-th input wire and +// the out_index-th output wire. +fn update_signature( + circ: &mut C, + in_index: usize, + out_index: Option, +) { + let inp = circ.input(); + // Update input node + let inp_types: TypeRow = { + let OpType::Input(Input { types }) = circ.get_optype(inp).clone() else { + panic!("invalid circuit") + }; + let mut types = types.into_owned(); + types.remove(in_index); + types.into() + }; + let new_inp_op = Input::new(inp_types.clone()); + let inp_exts = circ.get_nodetype(inp).input_extensions().cloned(); + circ.replace_op(inp, NodeType::new(new_inp_op, inp_exts)); + + // Update output node if necessary. + let out_types = out_index.map(|out_index| { + let out = circ.output(); + let out_types: TypeRow = { + let OpType::Output(Output { types }) = circ.get_optype(out).clone() else { + panic!("invalid circuit") + }; + let mut types = types.into_owned(); + types.remove(out_index); + types.into() + }; + let new_out_op = Output::new(out_types.clone()); + let inp_exts = circ.get_nodetype(out).input_extensions().cloned(); + circ.replace_op(out, NodeType::new(new_out_op, inp_exts)); + out_types + }); + + // Update root + let OpType::DFG(DFG { mut signature, .. }) = circ.get_optype(circ.root()).clone() else { + panic!("invalid circuit") + }; + signature.input = inp_types; + if let Some(out_types) = out_types { + signature.output = out_types; + } + let new_dfg_op = DFG { signature }; + let inp_exts = circ.get_nodetype(circ.root()).input_extensions().cloned(); + circ.replace_op(circ.root(), NodeType::new(new_dfg_op, inp_exts)); +} + impl Circuit for T where T: HugrView {} #[cfg(test)] mod tests { - use hugr::Hugr; + use hugr::{ + builder::{DFGBuilder, DataflowHugr}, + extension::{prelude::BOOL_T, PRELUDE_REGISTRY}, + Hugr, + }; - use crate::{circuit::Circuit, json::load_tk1_json_str}; + use super::*; + use crate::{json::load_tk1_json_str, utils::build_simple_circuit, T2Op}; fn test_circuit() -> Hugr { load_tk1_json_str( - r#"{ - "phase": "0", + r#"{ "phase": "0", "bits": [["c", [0]]], "qubits": [["q", [0]], ["q", [1]]], "commands": [ @@ -160,4 +314,35 @@ mod tests { assert_eq!(circ.linear_units().count(), 3); assert_eq!(circ.qubits().count(), 2); } + + #[test] + fn remove_qubit() { + let mut circ = build_simple_circuit(2, |circ| { + circ.append(T2Op::X, [0])?; + Ok(()) + }) + .unwrap(); + + assert_eq!(circ.qubit_count(), 2); + assert!(remove_empty_wire(&mut circ, 1).is_ok()); + assert_eq!(circ.qubit_count(), 1); + assert_eq!( + remove_empty_wire(&mut circ, 0).unwrap_err(), + CircuitMutError::DeleteNonEmptyWire(0) + ); + } + + #[test] + fn remove_bit() { + let h = DFGBuilder::new(FunctionType::new(vec![BOOL_T], vec![])).unwrap(); + let mut circ = h.finish_hugr_with_outputs([], &PRELUDE_REGISTRY).unwrap(); + + assert_eq!(circ.units().count(), 1); + assert!(remove_empty_wire(&mut circ, 0).is_ok()); + assert_eq!(circ.units().count(), 0); + assert_eq!( + remove_empty_wire(&mut circ, 2).unwrap_err(), + CircuitMutError::InvalidPortOffset(2) + ); + } }