From e9de224693e5a60452c9e1053d9b6e134e3b2238 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 19 Oct 2023 11:00:46 +0100 Subject: [PATCH] chore: Update hugr dependency --- Cargo.toml | 2 +- src/passes/chunks.rs | 7 +++---- src/passes/commutation.rs | 15 ++++++++++++++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e072b668..8a0d3362 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "9254ac7" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "9195d15" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/src/passes/chunks.rs b/src/passes/chunks.rs index fd12f13a..cb0652af 100644 --- a/src/passes/chunks.rs +++ b/src/passes/chunks.rs @@ -113,8 +113,7 @@ impl Chunk { .unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}")); let node_map = circ .insert_subgraph(root, &self.circ, &subgraph) - .expect("Failed to insert the chunk subgraph") - .node_map; + .expect("Failed to insert the chunk subgraph"); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); @@ -536,7 +535,7 @@ mod test { let mut reassembled = chunks.reassemble().unwrap(); - reassembled.infer_and_validate(®ISTRY).unwrap(); + reassembled.update_validate(®ISTRY).unwrap(); assert_eq!(circ.circuit_hash(), reassembled.circuit_hash()); } @@ -566,7 +565,7 @@ mod test { let mut reassembled = chunks.reassemble().unwrap(); - reassembled.infer_and_validate(®ISTRY).unwrap(); + reassembled.update_validate(®ISTRY).unwrap(); assert_eq!(reassembled.commands().count(), 1); let h = reassembled.commands().next().unwrap().node(); diff --git a/src/passes/commutation.rs b/src/passes/commutation.rs index 1d56d400..8df3a294 100644 --- a/src/passes/commutation.rs +++ b/src/passes/commutation.rs @@ -228,6 +228,10 @@ impl Rewrite for PullForward { type ApplyResult = (); + type InvalidationSet<'a> = std::vec::IntoIter + where + Self: 'a; + const UNCHANGED_ON_FAILURE: bool = false; fn verify(&self, _h: &impl HugrView) -> Result<(), Self::Error> { @@ -294,6 +298,15 @@ impl Rewrite for PullForward { } Ok(()) } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + // TODO: This could avoid creating a vec, but it'll be easier to do once + // return position impl trait is available. + let mut nodes = vec![self.command.node()]; + let next_nodes = self.new_nexts.values().map(|c| c.node()); + nodes.extend(next_nodes); + nodes.into_iter() + } } /// Pass which greedily commutes operations forwards in order to reduce depth. @@ -603,7 +616,7 @@ mod test { let node_count = case.node_count(); let depth_before = depth(&case); let move_count = apply_greedy_commutation(&mut case).unwrap(); - case.infer_and_validate(®ISTRY).unwrap(); + case.update_validate(®ISTRY).unwrap(); assert_eq!( move_count, expected_moves,