Skip to content

Commit

Permalink
chore: Update hugr, fixing a rewrite error (#122)
Browse files Browse the repository at this point in the history
- Also drops most siblinggraph uses now that we can use Hugr instead.
  • Loading branch information
aborgna-q authored Sep 20, 2023
1 parent 1d436e6 commit 67df468
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "660fef6e" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "981f4f9" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
5 changes: 1 addition & 4 deletions benches/benchmarks/hash.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use criterion::{black_box, criterion_group, AxisScale, BenchmarkId, Criterion, PlotConfiguration};
use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::HugrView;
use tket2::circuit::CircuitHash;

use super::generators::make_cnot_layers;
Expand All @@ -11,8 +9,7 @@ fn bench_hash_simple(c: &mut Criterion) {

for size in [10, 100, 1_000] {
g.bench_with_input(BenchmarkId::new("hash_simple", size), &size, |b, size| {
let hugr = make_cnot_layers(8, *size);
let circ: SiblingGraph<'_> = SiblingGraph::new(&hugr, hugr.root());
let circ = make_cnot_layers(8, *size);
b.iter(|| black_box(circ.circuit_hash()))
});
}
Expand Down
5 changes: 1 addition & 4 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,7 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for CommandIterator<'circ, Circ> {

#[cfg(test)]
mod test {
use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::ops::OpName;
use hugr::HugrView;
use itertools::Itertools;

use crate::utils::build_simple_circuit;
Expand All @@ -342,14 +340,13 @@ mod test {

#[test]
fn iterate_commands() {
let hugr = build_simple_circuit(2, |circ| {
let circ = build_simple_circuit(2, |circ| {
circ.append(T2Op::H, [0])?;
circ.append(T2Op::CX, [0, 1])?;
circ.append(T2Op::T, [1])?;
Ok(())
})
.unwrap();
let circ: SiblingGraph<'_> = SiblingGraph::new(&hugr, hugr.root());

assert_eq!(CommandIterator::new(&circ).count(), 3);

Expand Down
12 changes: 4 additions & 8 deletions src/circuit/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ fn hash_node(circ: &impl HugrView, node: Node, state: &mut HashState) -> u64 {

#[cfg(test)]
mod test {
use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::{Hugr, HugrView};
use hugr::Hugr;
use tket_json_rs::circuit_json;

use crate::json::TKETDecode;
Expand All @@ -160,38 +159,35 @@ mod test {

#[test]
fn hash_equality() {
let hugr1 = build_simple_circuit(2, |circ| {
let circ1 = build_simple_circuit(2, |circ| {
circ.append(T2Op::H, [0])?;
circ.append(T2Op::T, [1])?;
circ.append(T2Op::CX, [0, 1])?;
Ok(())
})
.unwrap();
let circ1: SiblingGraph<'_> = SiblingGraph::new(&hugr1, hugr1.root());
let hash1 = circ1.circuit_hash();

// A circuit built in a different order should have the same hash
let hugr2 = build_simple_circuit(2, |circ| {
let circ2 = build_simple_circuit(2, |circ| {
circ.append(T2Op::T, [1])?;
circ.append(T2Op::H, [0])?;
circ.append(T2Op::CX, [0, 1])?;
Ok(())
})
.unwrap();
let circ2: SiblingGraph<'_> = SiblingGraph::new(&hugr2, hugr2.root());
let hash2 = circ2.circuit_hash();

assert_eq!(hash1, hash2);

// Inverting the CX control and target should produce a different hash
let hugr3 = build_simple_circuit(2, |circ| {
let circ3 = build_simple_circuit(2, |circ| {
circ.append(T2Op::T, [1])?;
circ.append(T2Op::H, [0])?;
circ.append(T2Op::CX, [1, 0])?;
Ok(())
})
.unwrap();
let circ3: SiblingGraph<'_> = SiblingGraph::new(&hugr3, hugr3.root());
let hash3 = circ3.circuit_hash();

assert_ne!(hash1, hash3);
Expand Down
7 changes: 2 additions & 5 deletions src/json/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
use std::collections::HashSet;

use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::ops::handle::DfgID;
use hugr::{Hugr, HugrView};
use hugr::Hugr;
use rstest::rstest;
use tket_json_rs::circuit_json::{self, SerialCircuit};

Expand Down Expand Up @@ -58,8 +56,7 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num
let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap();
assert_eq!(ser.commands.len(), num_commands);

let hugr: Hugr = ser.clone().decode().unwrap();
let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root());
let circ: Hugr = ser.clone().decode().unwrap();

assert_eq!(circ.qubit_count(), num_qubits);

Expand Down
12 changes: 6 additions & 6 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::pyclass;

use crate::extension::REGISTRY;

/// Name of tket 2 extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("quantum.tket2");

Expand Down Expand Up @@ -201,6 +203,7 @@ pub fn symbolic_constant_op(s: &str) -> OpType {
vec![TypeArg::Opaque {
arg: CustomTypeArg::new(SYM_EXPR_T.clone(), value).unwrap(),
}],
&REGISTRY,
)
.unwrap()
.into();
Expand Down Expand Up @@ -274,7 +277,7 @@ pub static ref EXTENSION: Extension = {
impl From<T2Op> for LeafOp {
fn from(op: T2Op) -> Self {
EXTENSION
.instantiate_extension_op(op.name(), [])
.instantiate_extension_op(op.name(), [], &REGISTRY)
.unwrap()
.into()
}
Expand Down Expand Up @@ -322,8 +325,7 @@ pub(crate) mod test {

use std::sync::Arc;

use hugr::hugr::views::HierarchyView;
use hugr::{extension::OpDef, hugr::views::SiblingGraph, ops::handle::DfgID, Hugr, HugrView};
use hugr::{extension::OpDef, Hugr};
use rstest::{fixture, rstest};

use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit};
Expand Down Expand Up @@ -354,8 +356,6 @@ pub(crate) mod test {

#[rstest]
fn check_t2_bell(t2_bell_circuit: Hugr) {
let circ: SiblingGraph<'_, DfgID> =
SiblingGraph::new(&t2_bell_circuit, t2_bell_circuit.root());
assert_eq!(circ.commands().count(), 2);
assert_eq!(t2_bell_circuit.commands().count(), 2);
}
}
19 changes: 5 additions & 14 deletions src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ use std::fmt;

use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError;
use hugr::hugr::views::{DescendantsGraph, HierarchyView};
use hugr::ops::handle::DfgID;
use hugr::{Hugr, HugrView, Port};
use hugr::{Hugr, Port};
use itertools::Itertools;
use portmatching::{HashMap, PatternID};
use pyo3::{prelude::*, types::PyIterator};
Expand All @@ -22,8 +20,7 @@ impl CircuitPattern {
/// Construct a pattern from a TKET1 circuit
#[new]
pub fn py_from_circuit(circ: PyObject) -> PyResult<CircuitPattern> {
let hugr = pyobj_as_hugr(circ)?;
let circ = hugr_as_view(&hugr);
let circ = pyobj_as_hugr(circ)?;
let pattern = CircuitPattern::try_from_circuit(&circ)?;
Ok(pattern)
}
Expand Down Expand Up @@ -54,8 +51,7 @@ impl PatternMatcher {
/// Find all convex matches in a circuit.
#[pyo3(name = "find_matches")]
pub fn py_find_matches(&self, circ: PyObject) -> PyResult<Vec<PyPatternMatch>> {
let hugr = pyobj_as_hugr(circ)?;
let circ = hugr_as_view(&hugr);
let circ = pyobj_as_hugr(circ)?;
self.find_matches(&circ)
.into_iter()
.map(|m| {
Expand Down Expand Up @@ -160,8 +156,7 @@ impl PyPatternMatch {

/// Convert the pattern into a [`CircuitRewrite`].
pub fn to_rewrite(&self, circ: PyObject, replacement: PyObject) -> PyResult<CircuitRewrite> {
let hugr = pyobj_as_hugr(circ)?;
let circ = hugr_as_view(&hugr);
let circ = pyobj_as_hugr(circ)?;
let inputs = self
.inputs
.iter()
Expand All @@ -176,7 +171,7 @@ impl PyPatternMatch {
outputs,
)
.expect("Invalid PyCircuitMatch object")
.to_rewrite(&hugr, pyobj_as_hugr(replacement)?)?;
.to_rewrite(&circ, pyobj_as_hugr(replacement)?)?;
Ok(rewrite)
}
}
Expand Down Expand Up @@ -207,7 +202,3 @@ fn pyobj_as_hugr(circ: PyObject) -> PyResult<Hugr> {
let hugr: Hugr = ser_c.decode()?;
Ok(hugr)
}

fn hugr_as_view(hugr: &Hugr) -> DescendantsGraph<'_, DfgID> {
DescendantsGraph::new(hugr, hugr.root())
}
14 changes: 3 additions & 11 deletions src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ use itertools::Itertools;
use portmatching::PatternID;
use std::path::Path;

use hugr::{
hugr::views::{HierarchyView, SiblingGraph},
ops::handle::DfgID,
Hugr, HugrView,
};
use hugr::Hugr;

use crate::{
circuit::Circuit,
Expand Down Expand Up @@ -138,13 +134,9 @@ fn get_rewrite_rules(rep_sets: &[EqCircClass]) -> Vec<Vec<TargetID>> {
}

fn get_patterns(rep_sets: &[EqCircClass]) -> Vec<Option<CircuitPattern>> {
let all_hugrs = rep_sets.iter().flat_map(|rs| rs.circuits());
let all_circs = all_hugrs
.map(|hugr| SiblingGraph::<DfgID>::new(hugr, hugr.root()))
// TODO: resolve lifetime issues to avoid collecting to vec
.collect_vec();
all_circs
rep_sets
.iter()
.flat_map(|rs| rs.circuits())
.map(|circ| CircuitPattern::try_from_circuit(circ).ok())
.collect()
}
Expand Down

0 comments on commit 67df468

Please sign in to comment.