Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add CircuitMut trait #138

Merged
merged 9 commits into from
Sep 29, 2023
193 changes: 189 additions & 4 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@ pub use hash::CircuitHash;

use hugr::HugrView;

use derive_more::From;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::{NodeType, PortIndex};
use hugr::ops::dataflow::IOTrait;
pub use hugr::ops::OpType;
use hugr::ops::{Input, Output, DFG};
use hugr::types::FunctionType;
pub use hugr::types::{EdgeKind, Signature, Type, TypeRow};
pub use hugr::{Node, Port, Wire};
use itertools::Itertools;
use portgraph::Direction;
use thiserror::Error;

use self::units::{filter, FilteredUnits, Units};

Expand Down Expand Up @@ -120,18 +128,164 @@ pub trait Circuit: HugrView {
}
}

/// Remove an empty wire in a dataflow HUGR.
///
/// The wire to be removed is identified by the index of the outgoing port
/// at the circuit input node.
///
/// This will change the circuit signature and will shift all ports after
/// the removed wire by -1. If the wire is connected to the output node,
/// this will also change the signature output and shift the ports after
/// the removed wire by -1.
///
/// This will return an error if the wire is not empty or if a HugrError
/// occurs.
#[allow(dead_code)]
pub(crate) fn remove_empty_wire(
circ: &mut impl HugrMut,
input_port: usize,
) -> Result<(), CircuitMutError> {
let [inp, out] = circ.get_io(circ.root()).expect("no IO nodes found at root");
if input_port >= circ.num_outputs(inp) {
return Err(CircuitMutError::InvalidPortOffset(input_port));
}
let input_port = Port::new_outgoing(input_port);
let link = circ
.linked_ports(inp, input_port)
.at_most_one()
.map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?;
if link.is_some() && link.unwrap().0 != out {
return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index()));
}
if link.is_some() {
circ.disconnect(inp, input_port)?;
}

// Shift ports at input
shift_ports(circ, inp, input_port, circ.num_outputs(inp))?;
// Shift ports at output
if let Some((out, output_port)) = link {
shift_ports(circ, out, output_port, circ.num_inputs(out))?;
}
// Update input node, output node (if necessary) and root signatures.
update_signature(circ, input_port.index(), link.map(|(_, p)| p.index()));
// Resize ports at input/output node
circ.set_num_ports(inp, 0, circ.num_outputs(inp) - 1);
if let Some((out, _)) = link {
circ.set_num_ports(out, circ.num_inputs(out) - 1, 0);
}
Ok(())
}

/// Errors that can occur when mutating a circuit.
#[derive(Debug, Clone, Error, PartialEq, Eq, From)]
pub enum CircuitMutError {
/// A Hugr error occurred.
#[error("Hugr error: {0:?}")]
HugrError(hugr::hugr::HugrError),
/// The wire to be deleted is not empty.
#[from(ignore)]
#[error("Wire {0} cannot be deleted: not empty")]
DeleteNonEmptyWire(usize),
/// The wire does not exist.
#[from(ignore)]
#[error("Wire {0} does not exist")]
InvalidPortOffset(usize),
}

/// Shift ports in range (free_port + 1 .. max_ind) by -1.
fn shift_ports<C: HugrMut + ?Sized>(
circ: &mut C,
node: Node,
mut free_port: Port,
max_ind: usize,
) -> Result<Port, hugr::hugr::HugrError> {
let dir = free_port.direction();
let port_range = (free_port.index() + 1..max_ind).map(|p| Port::new(dir, p));
for port in port_range {
let links = circ.linked_ports(node, port).collect_vec();
if !links.is_empty() {
circ.disconnect(node, port)?;
}
for (other_n, other_p) in links {
// TODO: simplify when CQCL-DEV/hugr#565 is resolved
match dir {
Direction::Incoming => circ.connect(other_n, other_p, node, free_port),
Direction::Outgoing => circ.connect(node, free_port, other_n, other_p),
}?;
}
free_port = port;
}
Ok(free_port)
}

// Update the signature of circ when removing the in_index-th input wire and
// the out_index-th output wire.
fn update_signature<C: HugrMut + Circuit + ?Sized>(
circ: &mut C,
in_index: usize,
out_index: Option<usize>,
) {
let inp = circ.input();
// Update input node
let inp_types: TypeRow = {
let OpType::Input(Input { types }) = circ.get_optype(inp).clone() else {
panic!("invalid circuit")
};
let mut types = types.into_owned();
types.remove(in_index);
types.into()
};
let new_inp_op = Input::new(inp_types.clone());
let inp_exts = circ.get_nodetype(inp).input_extensions().cloned();
circ.replace_op(inp, NodeType::new(new_inp_op, inp_exts));

// Update output node if necessary.
let out_types = out_index.map(|out_index| {
let out = circ.output();
let out_types: TypeRow = {
let OpType::Output(Output { types }) = circ.get_optype(out).clone() else {
panic!("invalid circuit")
};
let mut types = types.into_owned();
types.remove(out_index);
types.into()
};
let new_out_op = Output::new(out_types.clone());
let inp_exts = circ.get_nodetype(out).input_extensions().cloned();
circ.replace_op(out, NodeType::new(new_out_op, inp_exts));
out_types
});

// Update root
let OpType::DFG(DFG { mut signature, .. }) = circ.get_optype(circ.root()).clone() else {
panic!("invalid circuit")
};
signature.input = inp_types;
if let Some(out_types) = out_types {
signature.output = out_types;
}
let new_dfg_op = DFG { signature };
let inp_exts = circ.get_nodetype(circ.root()).input_extensions().cloned();
circ.replace_op(circ.root(), NodeType::new(new_dfg_op, inp_exts));
}

impl<T> Circuit for T where T: HugrView {}

#[cfg(test)]
mod tests {
use hugr::Hugr;
use hugr::{
builder::{DFGBuilder, DataflowHugr},
extension::{prelude::BOOL_T, PRELUDE_REGISTRY},
Hugr,
};

use crate::{circuit::Circuit, json::load_tk1_json_str};
use super::*;
use crate::{json::load_tk1_json_str, utils::build_simple_circuit, T2Op};

fn test_circuit() -> Hugr {
load_tk1_json_str(
r#"{
"phase": "0",
r#"{ "phase": "0",
"bits": [["c", [0]]],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
Expand Down Expand Up @@ -160,4 +314,35 @@ mod tests {
assert_eq!(circ.linear_units().count(), 3);
assert_eq!(circ.qubits().count(), 2);
}

#[test]
fn remove_qubit() {
let mut circ = build_simple_circuit(2, |circ| {
circ.append(T2Op::X, [0])?;
Ok(())
})
.unwrap();

assert_eq!(circ.qubit_count(), 2);
assert!(remove_empty_wire(&mut circ, 1).is_ok());
assert_eq!(circ.qubit_count(), 1);
assert_eq!(
remove_empty_wire(&mut circ, 0).unwrap_err(),
CircuitMutError::DeleteNonEmptyWire(0)
);
}

#[test]
fn remove_bit() {
let h = DFGBuilder::new(FunctionType::new(vec![BOOL_T], vec![])).unwrap();
let mut circ = h.finish_hugr_with_outputs([], &PRELUDE_REGISTRY).unwrap();

assert_eq!(circ.units().count(), 1);
assert!(remove_empty_wire(&mut circ, 0).is_ok());
assert_eq!(circ.units().count(), 0);
assert_eq!(
remove_empty_wire(&mut circ, 2).unwrap_err(),
CircuitMutError::InvalidPortOffset(2)
);
}
}