From cf542b4f22218461c58a16194b9a21f1c70f9898 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 6 Jun 2024 17:39:54 +0100 Subject: [PATCH] feat: `HugrView::extract_hugr` to extract regions into owned hugrs. (#1173) Add a `extract_hugr` method to extract regions into owned `Hugr`s. The implementation is short-circuited for `Hugr` to avoid unnecessary clones. Closes #1171. --- hugr-core/src/hugr.rs | 6 ++++ hugr-core/src/hugr/views.rs | 37 ++++++++++++++++++++++++- hugr-core/src/hugr/views/descendants.rs | 4 ++- hugr-core/src/hugr/views/sibling.rs | 21 +++++++++++++- 4 files changed, 65 insertions(+), 3 deletions(-) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index f54e56655..0319df0dc 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -263,6 +263,12 @@ impl Hugr { } } + /// Set the root node of the hugr. + pub(crate) fn set_root(&mut self, root: Node) { + self.hierarchy.detach(self.root); + self.root = root.pg_index(); + } + /// Add a node to the graph. pub(crate) fn add_node(&mut self, nodetype: NodeType) -> Node { let node = self diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 8ad14190f..607f88768 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -26,7 +26,8 @@ use portgraph::{multiportgraph, LinkView, PortView}; use super::internal::HugrInternals; use super::{ - Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE, + Hugr, HugrError, HugrMut, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, + DEFAULT_NODETYPE, }; use crate::extension::ExtensionRegistry; use crate::ops::handle::NodeHandle; @@ -508,6 +509,21 @@ pub trait HierarchyView<'a>: RootTagged + Sized { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result; } +/// A trait for [`HugrView`]s that can be extracted into a valid HUGR containing +/// only the nodes and edges of the view. +pub trait ExtractHugr: HugrView + Sized { + /// Extracts the view into an owned HUGR, rooted at the view's root node + /// and containing only the nodes and edges of the view. + fn extract_hugr(self) -> Hugr { + let mut hugr = Hugr::default(); + let old_root = hugr.root(); + let new_root = hugr.insert_from_view(old_root, &self).new_root; + hugr.set_root(new_root); + hugr.remove_node(old_root); + hugr + } +} + fn check_tag(hugr: &impl HugrView, node: Node) -> Result<(), HugrError> { let actual = hugr.get_optype(node).tag(); let required = Required::TAG; @@ -529,6 +545,25 @@ impl RootTagged for &mut Hugr { type RootHandle = Node; } +// Explicit implementation to avoid cloning the Hugr. +impl ExtractHugr for Hugr { + fn extract_hugr(self) -> Hugr { + self + } +} + +impl ExtractHugr for &Hugr { + fn extract_hugr(self) -> Hugr { + self.clone() + } +} + +impl ExtractHugr for &mut Hugr { + fn extract_hugr(self) -> Hugr { + self.clone() + } +} + impl> HugrView for T { /// An Iterator over the nodes in a Hugr(View) type Nodes<'a> = MapInto, Node> where Self: 'a; diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 05d57ce90..bee6dbd26 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -9,7 +9,7 @@ use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; @@ -175,6 +175,8 @@ where } } +impl<'g, Root: NodeHandle> ExtractHugr for DescendantsGraph<'g, Root> {} + impl<'g, Root> super::HugrInternals for DescendantsGraph<'g, Root> where Root: NodeHandle, diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 90723b9e4..25bc477a4 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -11,7 +11,7 @@ use crate::hugr::{HugrError, HugrMut}; use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; @@ -205,6 +205,8 @@ where } } +impl<'g, Root: NodeHandle> ExtractHugr for SiblingGraph<'g, Root> {} + impl<'g, Root> HugrInternals for SiblingGraph<'g, Root> where Root: NodeHandle, @@ -268,6 +270,8 @@ impl<'g, Root: NodeHandle> SiblingMut<'g, Root> { } } +impl<'g, Root: NodeHandle> ExtractHugr for SiblingMut<'g, Root> {} + impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> { type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p; @@ -484,4 +488,19 @@ mod test { let nested_sib_mut = SiblingMut::::try_new(&mut sib_mut, root); assert!(nested_sib_mut.is_err()); } + + #[rstest] + fn extract_hugr() -> Result<(), Box> { + let (hugr, def, _inner) = make_module_hgr()?; + + let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; + let extracted = region.extract_hugr(); + + let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; + + assert_eq!(region.node_count(), extracted.node_count()); + assert_eq!(region.root_type(), extracted.root_type()); + + Ok(()) + } }