Skip to content

Commit

Permalink
refactor(Rewrite)!: return impl trait in Rewrite trait (#889)
Browse files Browse the repository at this point in the history
Get rid of the associated type in Rewrite trait
  • Loading branch information
Cobord authored Mar 20, 2024
1 parent 3406e88 commit 5381831
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 52 deletions.
11 changes: 2 additions & 9 deletions quantinuum-hugr/src/hugr/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ pub trait Rewrite {
type Error: std::error::Error;
/// The type returned on successful application of the rewrite.
type ApplyResult;
/// The node iterator returned by [`Rewrite::invalidation_set`]
type InvalidationSet<'a>: Iterator<Item = Node> + 'a
where
Self: 'a;

/// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err.
/// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned.
Expand All @@ -47,7 +43,7 @@ pub trait Rewrite {
///
/// Two `impl Rewrite`s can be composed if their invalidation sets are
/// disjoint.
fn invalidation_set(&self) -> Self::InvalidationSet<'_>;
fn invalidation_set(&self) -> impl Iterator<Item = Node>;
}

/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure)
Expand All @@ -60,9 +56,6 @@ pub struct Transactional<R> {
impl<R: Rewrite> Rewrite for Transactional<R> {
type Error = R::Error;
type ApplyResult = R::ApplyResult;
type InvalidationSet<'a> = R::InvalidationSet<'a>
where
Self: 'a;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -93,7 +86,7 @@ impl<R: Rewrite> Rewrite for Transactional<R> {
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
self.underlying.invalidation_set()
}
}
8 changes: 2 additions & 6 deletions quantinuum-hugr/src/hugr/rewrite/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ impl Rewrite for RemoveLoadConstant {
// The Const node the LoadConstant was connected to.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -64,7 +62,7 @@ impl Rewrite for RemoveLoadConstant {
Ok(source)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
iter::once(self.0)
}
}
Expand All @@ -79,8 +77,6 @@ impl Rewrite for RemoveConst {
// The parent of the Const node.
type ApplyResult = Node;

type InvalidationSet<'a> = iter::Once<Node>;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -108,7 +104,7 @@ impl Rewrite for RemoveConst {
Ok(parent)
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
iter::once(self.0)
}
}
Expand Down
4 changes: 1 addition & 3 deletions quantinuum-hugr/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ impl Rewrite for InlineDFG {
type ApplyResult = [Node; 3];
type Error = InlineDFGError;

type InvalidationSet<'a> = <[Node; 1] as IntoIterator>::IntoIter;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, h: &impl crate::HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -122,7 +120,7 @@ impl Rewrite for InlineDFG {
Ok([n, input, output])
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
[self.0.node()].into_iter()
}
}
Expand Down
5 changes: 1 addition & 4 deletions quantinuum-hugr/src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ impl Rewrite for IdentityInsertion {
type Error = IdentityInsertionError;
/// The inserted node.
type ApplyResult = Node;
type InvalidationSet<'a> = iter::Once<Node>
where
Self: 'a;
const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
/*
Expand Down Expand Up @@ -90,7 +87,7 @@ impl Rewrite for IdentityInsertion {
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
iter::once(self.post_node)
}
}
Expand Down
9 changes: 2 additions & 7 deletions quantinuum-hugr/src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG
use std::collections::{hash_set, HashSet};
use std::iter;
use std::collections::HashSet;

use itertools::Itertools;
use thiserror::Error;
Expand Down Expand Up @@ -101,9 +100,6 @@ impl Rewrite for OutlineCfg {
///
/// [CFG]: OpType::CFG
type ApplyResult = (Node, Node);
type InvalidationSet<'a> = iter::Copied<hash_set::Iter<'a, Node>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
Expand Down Expand Up @@ -216,8 +212,7 @@ impl Rewrite for OutlineCfg {
Ok((new_block, cfg_node))
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
self.blocks.iter().copied()
}
}
Expand Down
8 changes: 1 addition & 7 deletions quantinuum-hugr/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Implementation of the `Replace` operation.
use std::collections::{HashMap, HashSet, VecDeque};
use std::iter::Copied;
use std::slice::Iter;

use itertools::Itertools;
use thiserror::Error;
Expand Down Expand Up @@ -215,10 +213,6 @@ impl Rewrite for Replacement {

type ApplyResult = ();

type InvalidationSet<'a> = Copied<Iter<'a, Node>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = false;

fn verify(&self, h: &impl crate::HugrView) -> Result<(), Self::Error> {
Expand Down Expand Up @@ -331,7 +325,7 @@ impl Rewrite for Replacement {
Ok(())
}

fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
self.removal.iter().copied()
}
}
Expand Down
19 changes: 3 additions & 16 deletions quantinuum-hugr/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Implementation of the `SimpleReplace` operation.
use std::collections::{hash_map, HashMap};
use std::iter::{self, Copied};
use std::slice;
use std::collections::HashMap;

use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite};
Expand Down Expand Up @@ -55,19 +53,9 @@ impl SimpleReplacement {
}
}

type SubgraphNodesIter<'a> = Copied<slice::Iter<'a, Node>>;
type NuOutNodesIter<'a> = iter::Map<
hash_map::Keys<'a, (Node, IncomingPort), IncomingPort>,
fn(&'a (Node, IncomingPort)) -> Node,
>;

impl Rewrite for SimpleReplacement {
type Error = SimpleReplacementError;
type ApplyResult = ();
type InvalidationSet<'a> = iter::Chain<SubgraphNodesIter<'a>, NuOutNodesIter<'a>>
where
Self: 'a;

const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, _h: &impl HugrView) -> Result<(), SimpleReplacementError> {
Expand Down Expand Up @@ -184,10 +172,9 @@ impl Rewrite for SimpleReplacement {
}

#[inline]
fn invalidation_set(&self) -> Self::InvalidationSet<'_> {
fn invalidation_set(&self) -> impl Iterator<Item = Node> {
let subcirc = self.subgraph.nodes().iter().copied();
let get_node: fn(&(Node, IncomingPort)) -> Node = |key| key.0;
let out_neighs = self.nu_out.keys().map(get_node);
let out_neighs = self.nu_out.keys().map(|key| key.0);
subcirc.chain(out_neighs)
}
}
Expand Down

0 comments on commit 5381831

Please sign in to comment.