Skip to content

Commit

Permalink
feat: HugrView::extract_hugr to extract regions into owned hugrs. (#…
Browse files Browse the repository at this point in the history
…1173)

Add a `extract_hugr` method to extract regions into owned `Hugr`s.
The implementation is short-circuited for `Hugr` to avoid unnecessary
clones.

Closes #1171.
  • Loading branch information
aborgna-q authored Jun 6, 2024
1 parent 5da06e1 commit cf542b4
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 3 deletions.
6 changes: 6 additions & 0 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ impl Hugr {
}
}

/// Set the root node of the hugr.
pub(crate) fn set_root(&mut self, root: Node) {
self.hierarchy.detach(self.root);
self.root = root.pg_index();
}

/// Add a node to the graph.
pub(crate) fn add_node(&mut self, nodetype: NodeType) -> Node {
let node = self
Expand Down
37 changes: 36 additions & 1 deletion hugr-core/src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use portgraph::{multiportgraph, LinkView, PortView};

use super::internal::HugrInternals;
use super::{
Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE,
Hugr, HugrError, HugrMut, NodeMetadata, NodeMetadataMap, NodeType, ValidationError,
DEFAULT_NODETYPE,
};
use crate::extension::ExtensionRegistry;
use crate::ops::handle::NodeHandle;
Expand Down Expand Up @@ -508,6 +509,21 @@ pub trait HierarchyView<'a>: RootTagged + Sized {
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError>;
}

/// A trait for [`HugrView`]s that can be extracted into a valid HUGR containing
/// only the nodes and edges of the view.
pub trait ExtractHugr: HugrView + Sized {
/// Extracts the view into an owned HUGR, rooted at the view's root node
/// and containing only the nodes and edges of the view.
fn extract_hugr(self) -> Hugr {
let mut hugr = Hugr::default();
let old_root = hugr.root();
let new_root = hugr.insert_from_view(old_root, &self).new_root;
hugr.set_root(new_root);
hugr.remove_node(old_root);
hugr
}
}

fn check_tag<Required: NodeHandle>(hugr: &impl HugrView, node: Node) -> Result<(), HugrError> {
let actual = hugr.get_optype(node).tag();
let required = Required::TAG;
Expand All @@ -529,6 +545,25 @@ impl RootTagged for &mut Hugr {
type RootHandle = Node;
}

// Explicit implementation to avoid cloning the Hugr.
impl ExtractHugr for Hugr {
fn extract_hugr(self) -> Hugr {
self
}
}

impl ExtractHugr for &Hugr {
fn extract_hugr(self) -> Hugr {
self.clone()
}
}

impl ExtractHugr for &mut Hugr {
fn extract_hugr(self) -> Hugr {
self.clone()
}
}

impl<T: AsRef<Hugr>> HugrView for T {
/// An Iterator over the nodes in a Hugr(View)
type Nodes<'a> = MapInto<multiportgraph::Nodes<'a>, Node> where Self: 'a;
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::hugr::HugrError;
use crate::ops::handle::NodeHandle;
use crate::{Direction, Hugr, Node, Port};

use super::{check_tag, HierarchyView, HugrInternals, HugrView, RootTagged};
use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};

type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -175,6 +175,8 @@ where
}
}

impl<'g, Root: NodeHandle> ExtractHugr for DescendantsGraph<'g, Root> {}

impl<'g, Root> super::HugrInternals for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
Expand Down
21 changes: 20 additions & 1 deletion hugr-core/src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::hugr::{HugrError, HugrMut};
use crate::ops::handle::NodeHandle;
use crate::{Direction, Hugr, Node, Port};

use super::{check_tag, HierarchyView, HugrInternals, HugrView, RootTagged};
use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};

type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -205,6 +205,8 @@ where
}
}

impl<'g, Root: NodeHandle> ExtractHugr for SiblingGraph<'g, Root> {}

impl<'g, Root> HugrInternals for SiblingGraph<'g, Root>
where
Root: NodeHandle,
Expand Down Expand Up @@ -268,6 +270,8 @@ impl<'g, Root: NodeHandle> SiblingMut<'g, Root> {
}
}

impl<'g, Root: NodeHandle> ExtractHugr for SiblingMut<'g, Root> {}

impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> {
type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p;

Expand Down Expand Up @@ -484,4 +488,19 @@ mod test {
let nested_sib_mut = SiblingMut::<DataflowParentID>::try_new(&mut sib_mut, root);
assert!(nested_sib_mut.is_err());
}

#[rstest]
fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, _inner) = make_module_hgr()?;

let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?;
let extracted = region.extract_hugr();

let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?;

assert_eq!(region.node_count(), extracted.node_count());
assert_eq!(region.root_type(), extracted.root_type());

Ok(())
}
}

0 comments on commit cf542b4

Please sign in to comment.