Skip to content

Commit

Permalink
fix: Support replacement with non-DFG-root targets
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Jun 17, 2024
1 parent f18aa4d commit 741e264
Show file tree
Hide file tree
Showing 8 changed files with 362 additions and 78 deletions.
320 changes: 273 additions & 47 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ lto = "thin"

[workspace]
resolver = "2"
members = ["tket2", "tket2-py", "compile-rewriter", "badger-optimiser", "tket2-hseries"]
members = [
"tket2",
"tket2-py",
"compile-rewriter",
"badger-optimiser",
"tket2-hseries",
]
default-members = ["tket2", "tket2-hseries"]

[workspace.package]
Expand All @@ -21,6 +27,7 @@ missing_docs = "warn"
tket2 = { path = "./tket2" }
hugr = "0.5.1"
hugr-cli = "0.1.1"
hugr-core = "0.2.0"
portgraph = "0.12"
pyo3 = "0.21.2"
itertools = "0.13.0"
Expand Down
1 change: 1 addition & 0 deletions tket2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ serde_yaml = { workspace = true }
portmatching = { workspace = true, optional = true, features = ["serde"] }
derive_more = { workspace = true }
hugr = { workspace = true }
hugr-core = { workspace = true }
portgraph = { workspace = true, features = ["serde"] }
strum_macros = { workspace = true }
strum = { workspace = true }
Expand Down
42 changes: 36 additions & 6 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ use std::iter::Sum;

pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView};
use itertools::Either::{Left, Right};

use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::NodeType;
use hugr::ops::dataflow::IOTrait;
use hugr::ops::{Input, NamedOp, OpParent, OpTag, OpTrait, Output};
use hugr::ops::{Input, NamedOp, OpParent, OpTag, OpTrait, Output, DFG};
use hugr::types::{FunctionType, PolyFuncType};
use hugr::{Hugr, PortIndex};
use hugr::{HugrView, OutgoingPort};
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Itertools;
use thiserror::Error;

Expand Down Expand Up @@ -299,6 +301,37 @@ impl<T: HugrView> Circuit<T> {
// TODO: See comment in `dot_string`.
self.hugr.mermaid_string()
}

/// Extracts the circuit into a new owned HUGR containing the circuit at the root.
/// Replaces the circuit container operation with an [`OpType::Dfg`].
///
/// Regions that are not descendants of the parent node are not included in the new HUGR.
/// This may invalidate calls to functions defined elsewhere. Make sure to inline any
/// external functions before calling this method.
pub fn extract_dfg(self) -> Result<Circuit<Hugr>, CircuitMutError>
where
T: ExtractHugr,
{
let mut circ = if self.parent == self.hugr.root() {
self.to_owned()
} else {
let view: DescendantsGraph = DescendantsGraph::try_new(&self.hugr, self.parent)
.expect("Circuit parent was not a dataflow container.");
view.extract_hugr().into()
};

// Replace the parent node with a DFG node, if necessary.
let nodetype = circ.hugr.get_nodetype(circ.parent());
if !matches!(nodetype.op(), OpType::DFG(_)) {
let dfg = DFG {
signature: circ.circuit_signature(),
};
let input_extensions = nodetype.input_extensions().cloned();
let nodetype = NodeType::new(OpType::DFG(dfg), input_extensions);
circ.hugr.replace_op(circ.parent(), nodetype)?;
}
Ok(circ)
}
}

impl<T: HugrView> From<T> for Circuit<T> {
Expand Down Expand Up @@ -648,12 +681,9 @@ mod tests {
#[case] circ: Circuit,
#[case] qubits: usize,
#[case] bits: usize,
#[case] _name: Option<&str>,
#[case] name: Option<&str>,
) {
// TODO: The decoder discards the circuit name.
// This requires decoding circuits into `FuncDefn` nodes instead of `Dfg`,
// but currently that causes errors with the replacement methods.
//assert_eq!(circ.name(), name);
assert_eq!(circ.name(), name);
assert_eq!(circ.circuit_signature().input_count(), qubits + bits);
assert_eq!(circ.circuit_signature().output_count(), qubits + bits);
assert_eq!(circ.qubit_count(), qubits);
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl From<InvalidSubgraph> for InvalidPatternMatch {
InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _),
) => InvalidPatternMatch::NotConvex,
InvalidSubgraph::EmptySubgraph => InvalidPatternMatch::EmptyMatch,
InvalidSubgraph::NoSharedParent | InvalidSubgraph::InvalidBoundary(_) => {
InvalidSubgraph::NoSharedParent { .. } | InvalidSubgraph::InvalidBoundary(_) => {
InvalidPatternMatch::InvalidSubcircuit
}
other => InvalidPatternMatch::Other(other),
Expand Down
40 changes: 27 additions & 13 deletions tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use ecc_rewriter::ECCRewriter;
use derive_more::{From, Into};
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph};
use hugr::hugr::views::ExtractHugr;
use hugr::{
hugr::{views::SiblingSubgraph, Rewrite, SimpleReplacementError},
SimpleReplacement,
Expand All @@ -33,7 +34,7 @@ impl Subcircuit {
/// Create a new subcircuit induced from a set of nodes.
pub fn try_from_nodes(
nodes: impl Into<Vec<Node>>,
circ: &Circuit,
circ: &Circuit<impl HugrView>,
) -> Result<Self, InvalidSubgraph> {
let subgraph = SiblingSubgraph::try_from_nodes(nodes, circ.hugr())?;
Ok(Self { subgraph })
Expand All @@ -49,16 +50,25 @@ impl Subcircuit {
self.subgraph.node_count()
}

/// Create a rewrite rule to replace the subcircuit.
/// Create a rewrite rule to replace the subcircuit with a new circuit.
///
/// # Parameters
/// * `circuit` - The base circuit that contains the subcircuit.
/// * `replacement` - The new circuit to replace the subcircuit with.
pub fn create_rewrite(
&self,
source: &Circuit,
target: Circuit,
circuit: &Circuit<impl HugrView>,
replacement: Circuit<impl ExtractHugr>,
) -> Result<CircuitRewrite, InvalidReplacement> {
Ok(CircuitRewrite(self.subgraph.create_simple_replacement(
source.hugr(),
target.into_hugr(),
)?))
// The replacement must be a Dfg rooted hugr.
let replacement = replacement
.extract_dfg()
.unwrap_or_else(|e| panic!("{}", e))
.into_hugr();
Ok(CircuitRewrite(
self.subgraph
.create_simple_replacement(circuit.hugr(), replacement)?,
))
}
}

Expand All @@ -69,13 +79,17 @@ pub struct CircuitRewrite(SimpleReplacement);
impl CircuitRewrite {
/// Create a new rewrite rule.
pub fn try_new(
source_position: &Subcircuit,
source: &Circuit<impl HugrView>,
target: Circuit,
circuit_position: &Subcircuit,
circuit: &Circuit<impl HugrView>,
replacement: Circuit<impl ExtractHugr>,
) -> Result<Self, InvalidReplacement> {
source_position
let replacement = replacement
.extract_dfg()
.unwrap_or_else(|e| panic!("{}", e))
.into_hugr();
circuit_position
.subgraph
.create_simple_replacement(source.hugr(), target.into_hugr())
.create_simple_replacement(circuit.hugr(), replacement)
.map(Self)
}

Expand Down
12 changes: 6 additions & 6 deletions tket2/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::mem;

use hugr::builder::{CircuitBuilder, Container, DFGBuilder, Dataflow, DataflowHugr};
use hugr::builder::{CircuitBuilder, Container, Dataflow, DataflowHugr, FunctionBuilder};
use hugr::extension::prelude::QB_T;

use hugr::types::FunctionType;
Expand All @@ -22,13 +22,13 @@ use super::{METADATA_B_REGISTERS, METADATA_Q_REGISTERS};
use crate::extension::{LINEAR_BIT, REGISTRY};
use crate::symbolic_constant_op;

/// The state of an in-progress [`DFGBuilder`] being built from a [`SerialCircuit`].
/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`].
///
/// Mostly used to define helper internal methods.
#[derive(Debug, PartialEq)]
pub(super) struct JsonDecoder {
/// The Hugr being built.
pub hugr: DFGBuilder<Hugr>,
pub hugr: FunctionBuilder<Hugr>,
/// The dangling wires of the builder.
/// Used to generate [`CircuitBuilder`]s.
dangling_wires: Vec<Wire>,
Expand Down Expand Up @@ -66,8 +66,8 @@ impl JsonDecoder {
);
// .with_extension_delta(&ExtensionSet::singleton(&TKET1_EXTENSION_ID));

// TODO: Use a FunctionBuilder and store the circuit name there.
let mut dfg = DFGBuilder::new(sig).unwrap();
let name = serialcirc.name.clone().unwrap_or_default();
let mut dfg = FunctionBuilder::new(name, sig.into()).unwrap();

// Metadata. The circuit requires "name", and we store other things that
// should pass through the serialization roundtrip.
Expand Down Expand Up @@ -128,7 +128,7 @@ impl JsonDecoder {
}

/// Apply a function to the internal hugr builder viewed as a [`CircuitBuilder`].
fn with_circ_builder(&mut self, f: impl FnOnce(&mut CircuitBuilder<DFGBuilder<Hugr>>)) {
fn with_circ_builder(&mut self, f: impl FnOnce(&mut CircuitBuilder<FunctionBuilder<Hugr>>)) {
let mut circ = self.hugr.as_circuit(mem::take(&mut self.dangling_wires));
f(&mut circ);
self.dangling_wires = circ.finish();
Expand Down
14 changes: 10 additions & 4 deletions tket2/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
use hugr::builder::{Container, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder};
use hugr::extension::PRELUDE_REGISTRY;
use hugr::ops::handle::NodeHandle;
use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use hugr::std_extensions::arithmetic::float_types;
use hugr::types::{Type, TypeBound};
use hugr::Hugr;
use hugr::{
builder::{BuildError, CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr},
builder::{BuildError, CircuitBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
types::FunctionType,
};
Expand All @@ -21,10 +23,12 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool {
#[allow(unused)]
pub(crate) fn build_simple_circuit<F>(num_qubits: usize, f: F) -> Result<Circuit, BuildError>
where
F: FnOnce(&mut CircuitBuilder<DFGBuilder<Hugr>>) -> Result<(), BuildError>,
F: FnOnce(&mut CircuitBuilder<FunctionBuilder<Hugr>>) -> Result<(), BuildError>,
{
let qb_row = vec![QB_T; num_qubits];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?;
let signature =
FunctionType::new(qb_row.clone(), qb_row).with_extension_delta(float_types::EXTENSION_ID);
let mut h = FunctionBuilder::new("main", signature.into())?;

let qbs = h.input_wires();

Expand All @@ -33,7 +37,9 @@ where
f(&mut circ)?;

let qbs = circ.finish();
let hugr = h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY)?;

// The float ops registry is required to define constant float values.
let hugr = h.finish_hugr_with_outputs(qbs, &FLOAT_OPS_REGISTRY)?;
Ok(hugr.into())
}

Expand Down

0 comments on commit 741e264

Please sign in to comment.