From d823d4052d33902bd0bbeb9437da5819b6e61ce9 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 20 Sep 2023 15:58:05 +0100 Subject: [PATCH] chore: Update hugr, fixing a rewrite error - Also drops most SiblingGraph uses --- Cargo.toml | 2 +- benches/benchmarks/hash.rs | 5 +---- src/circuit/command.rs | 5 +---- src/circuit/hash.rs | 12 ++++-------- src/json/tests.rs | 7 ++----- src/ops.rs | 12 ++++++------ src/portmatching/pyo3.rs | 19 +++++-------------- src/rewrite/ecc_rewriter.rs | 14 +++----------- 8 files changed, 23 insertions(+), 53 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bc32b1a4..41ba49bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "660fef6e" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "981f4f9" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/benches/benchmarks/hash.rs b/benches/benchmarks/hash.rs index a9f28c87..e1e37da2 100644 --- a/benches/benchmarks/hash.rs +++ b/benches/benchmarks/hash.rs @@ -1,6 +1,4 @@ use criterion::{black_box, criterion_group, AxisScale, BenchmarkId, Criterion, PlotConfiguration}; -use hugr::hugr::views::{HierarchyView, SiblingGraph}; -use hugr::HugrView; use tket2::circuit::CircuitHash; use super::generators::make_cnot_layers; @@ -11,8 +9,7 @@ fn bench_hash_simple(c: &mut Criterion) { for size in [10, 100, 1_000] { g.bench_with_input(BenchmarkId::new("hash_simple", size), &size, |b, size| { - let hugr = make_cnot_layers(8, *size); - let circ: SiblingGraph<'_> = SiblingGraph::new(&hugr, hugr.root()); + let circ = make_cnot_layers(8, *size); b.iter(|| black_box(circ.circuit_hash())) }); } diff --git a/src/circuit/command.rs b/src/circuit/command.rs index ac8b68c2..7ca21054 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -323,9 +323,7 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for CommandIterator<'circ, Circ> { #[cfg(test)] mod test { - use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::ops::OpName; - use hugr::HugrView; use itertools::Itertools; use crate::utils::build_simple_circuit; @@ -342,14 +340,13 @@ mod test { #[test] fn iterate_commands() { - let hugr = build_simple_circuit(2, |circ| { + let circ = build_simple_circuit(2, |circ| { circ.append(T2Op::H, [0])?; circ.append(T2Op::CX, [0, 1])?; circ.append(T2Op::T, [1])?; Ok(()) }) .unwrap(); - let circ: SiblingGraph<'_> = SiblingGraph::new(&hugr, hugr.root()); assert_eq!(CommandIterator::new(&circ).count(), 3); diff --git a/src/circuit/hash.rs b/src/circuit/hash.rs index aaea3b9b..fa2ffb12 100644 --- a/src/circuit/hash.rs +++ b/src/circuit/hash.rs @@ -148,8 +148,7 @@ fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> u64 { #[cfg(test)] mod test { - use hugr::hugr::views::{HierarchyView, SiblingGraph}; - use hugr::{Hugr, HugrView}; + use hugr::Hugr; use tket_json_rs::circuit_json; use crate::json::TKETDecode; @@ -160,38 +159,35 @@ mod test { #[test] fn hash_equality() { - let hugr1 = build_simple_circuit(2, |circ| { + let circ1 = build_simple_circuit(2, |circ| { circ.append(T2Op::H, [0])?; circ.append(T2Op::T, [1])?; circ.append(T2Op::CX, [0, 1])?; Ok(()) }) .unwrap(); - let circ1: SiblingGraph<'_> = SiblingGraph::new(&hugr1, hugr1.root()); let hash1 = circ1.circuit_hash(); // A circuit built in a different order should have the same hash - let hugr2 = build_simple_circuit(2, |circ| { + let circ2 = build_simple_circuit(2, |circ| { circ.append(T2Op::T, [1])?; circ.append(T2Op::H, [0])?; circ.append(T2Op::CX, [0, 1])?; Ok(()) }) .unwrap(); - let circ2: SiblingGraph<'_> = SiblingGraph::new(&hugr2, hugr2.root()); let hash2 = circ2.circuit_hash(); assert_eq!(hash1, hash2); // Inverting the CX control and target should produce a different hash - let hugr3 = build_simple_circuit(2, |circ| { + let circ3 = build_simple_circuit(2, |circ| { circ.append(T2Op::T, [1])?; circ.append(T2Op::H, [0])?; circ.append(T2Op::CX, [1, 0])?; Ok(()) }) .unwrap(); - let circ3: SiblingGraph<'_> = SiblingGraph::new(&hugr3, hugr3.root()); let hash3 = circ3.circuit_hash(); assert_ne!(hash1, hash3); diff --git a/src/json/tests.rs b/src/json/tests.rs index 76f7cb9c..78c8cc3c 100644 --- a/src/json/tests.rs +++ b/src/json/tests.rs @@ -2,9 +2,7 @@ use std::collections::HashSet; -use hugr::hugr::views::{HierarchyView, SiblingGraph}; -use hugr::ops::handle::DfgID; -use hugr::{Hugr, HugrView}; +use hugr::Hugr; use rstest::rstest; use tket_json_rs::circuit_json::{self, SerialCircuit}; @@ -58,8 +56,7 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap(); assert_eq!(ser.commands.len(), num_commands); - let hugr: Hugr = ser.clone().decode().unwrap(); - let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root()); + let circ: Hugr = ser.clone().decode().unwrap(); assert_eq!(circ.qubit_count(), num_qubits); diff --git a/src/ops.rs b/src/ops.rs index 782fbd55..3d5bd9e8 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -26,6 +26,8 @@ use thiserror::Error; #[cfg(feature = "pyo3")] use pyo3::pyclass; +use crate::extension::REGISTRY; + /// Name of tket 2 extension. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2"); @@ -201,6 +203,7 @@ pub fn symbolic_constant_op(s: &str) -> OpType { vec![TypeArg::Opaque { arg: CustomTypeArg::new(SYM_EXPR_T.clone(), value).unwrap(), }], + ®ISTRY, ) .unwrap() .into(); @@ -274,7 +277,7 @@ pub static ref EXTENSION: Extension = { impl From for LeafOp { fn from(op: T2Op) -> Self { EXTENSION - .instantiate_extension_op(op.name(), []) + .instantiate_extension_op(op.name(), [], ®ISTRY) .unwrap() .into() } @@ -322,8 +325,7 @@ pub(crate) mod test { use std::sync::Arc; - use hugr::hugr::views::HierarchyView; - use hugr::{extension::OpDef, hugr::views::SiblingGraph, ops::handle::DfgID, Hugr, HugrView}; + use hugr::{extension::OpDef, Hugr}; use rstest::{fixture, rstest}; use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit}; @@ -354,8 +356,6 @@ pub(crate) mod test { #[rstest] fn check_t2_bell(t2_bell_circuit: Hugr) { - let circ: SiblingGraph<'_, DfgID> = - SiblingGraph::new(&t2_bell_circuit, t2_bell_circuit.root()); - assert_eq!(circ.commands().count(), 2); + assert_eq!(t2_bell_circuit.commands().count(), 2); } } diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index 30d213a9..1ff2a7b6 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -4,9 +4,7 @@ use std::fmt; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError; -use hugr::hugr::views::{DescendantsGraph, HierarchyView}; -use hugr::ops::handle::DfgID; -use hugr::{Hugr, HugrView, Port}; +use hugr::{Hugr, Port}; use itertools::Itertools; use portmatching::{HashMap, PatternID}; use pyo3::{prelude::*, types::PyIterator}; @@ -22,8 +20,7 @@ impl CircuitPattern { /// Construct a pattern from a TKET1 circuit #[new] pub fn py_from_circuit(circ: PyObject) -> PyResult { - let hugr = pyobj_as_hugr(circ)?; - let circ = hugr_as_view(&hugr); + let circ = pyobj_as_hugr(circ)?; let pattern = CircuitPattern::try_from_circuit(&circ)?; Ok(pattern) } @@ -54,8 +51,7 @@ impl PatternMatcher { /// Find all convex matches in a circuit. #[pyo3(name = "find_matches")] pub fn py_find_matches(&self, circ: PyObject) -> PyResult> { - let hugr = pyobj_as_hugr(circ)?; - let circ = hugr_as_view(&hugr); + let circ = pyobj_as_hugr(circ)?; self.find_matches(&circ) .into_iter() .map(|m| { @@ -160,8 +156,7 @@ impl PyPatternMatch { /// Convert the pattern into a [`CircuitRewrite`]. pub fn to_rewrite(&self, circ: PyObject, replacement: PyObject) -> PyResult { - let hugr = pyobj_as_hugr(circ)?; - let circ = hugr_as_view(&hugr); + let circ = pyobj_as_hugr(circ)?; let inputs = self .inputs .iter() @@ -176,7 +171,7 @@ impl PyPatternMatch { outputs, ) .expect("Invalid PyCircuitMatch object") - .to_rewrite(&hugr, pyobj_as_hugr(replacement)?)?; + .to_rewrite(&circ, pyobj_as_hugr(replacement)?)?; Ok(rewrite) } } @@ -207,7 +202,3 @@ fn pyobj_as_hugr(circ: PyObject) -> PyResult { let hugr: Hugr = ser_c.decode()?; Ok(hugr) } - -fn hugr_as_view(hugr: &Hugr) -> DescendantsGraph<'_, DfgID> { - DescendantsGraph::new(hugr, hugr.root()) -} diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 05ca925a..40f1b7eb 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -17,11 +17,7 @@ use itertools::Itertools; use portmatching::PatternID; use std::path::Path; -use hugr::{ - hugr::views::{HierarchyView, SiblingGraph}, - ops::handle::DfgID, - Hugr, HugrView, -}; +use hugr::Hugr; use crate::{ circuit::Circuit, @@ -138,13 +134,9 @@ fn get_rewrite_rules(rep_sets: &[EqCircClass]) -> Vec> { } fn get_patterns(rep_sets: &[EqCircClass]) -> Vec> { - let all_hugrs = rep_sets.iter().flat_map(|rs| rs.circuits()); - let all_circs = all_hugrs - .map(|hugr| SiblingGraph::::new(hugr, hugr.root())) - // TODO: resolve lifetime issues to avoid collecting to vec - .collect_vec(); - all_circs + rep_sets .iter() + .flat_map(|rs| rs.circuits()) .map(|circ| CircuitPattern::try_from_circuit(circ).ok()) .collect() }