Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: HugrView::extract_hugr to extract regions into owned hugrs. #1173

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ impl Hugr {
}
}

/// Set the root node of the hugr.
pub(crate) fn set_root(&mut self, root: Node) {
self.hierarchy.detach(self.root);
self.root = root.pg_index();
}

/// Add a node to the graph.
pub(crate) fn add_node(&mut self, nodetype: NodeType) -> Node {
let node = self
Expand Down
37 changes: 36 additions & 1 deletion hugr-core/src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use portgraph::{multiportgraph, LinkView, PortView};

use super::internal::HugrInternals;
use super::{
Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE,
Hugr, HugrError, HugrMut, NodeMetadata, NodeMetadataMap, NodeType, ValidationError,
DEFAULT_NODETYPE,
};
use crate::extension::ExtensionRegistry;
use crate::ops::handle::NodeHandle;
Expand Down Expand Up @@ -508,6 +509,21 @@ pub trait HierarchyView<'a>: RootTagged + Sized {
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError>;
}

/// A trait for [`HugrView`]s that can be extracted into a valid HUGR containing
/// only the nodes and edges of the view.
pub trait ExtractHugr: HugrView + Sized {
/// Extracts the view into an owned HUGR, rooted at the view's root node
/// and containing only the nodes and edges of the view.
fn extract_hugr(self) -> Hugr {
let mut hugr = Hugr::default();
let old_root = hugr.root();
let new_root = hugr.insert_from_view(old_root, &self).new_root;
hugr.set_root(new_root);
hugr.remove_node(old_root);
hugr
}
}

fn check_tag<Required: NodeHandle>(hugr: &impl HugrView, node: Node) -> Result<(), HugrError> {
let actual = hugr.get_optype(node).tag();
let required = Required::TAG;
Expand All @@ -529,6 +545,25 @@ impl RootTagged for &mut Hugr {
type RootHandle = Node;
}

// Explicit implementation to avoid cloning the Hugr.
impl ExtractHugr for Hugr {
fn extract_hugr(self) -> Hugr {
self
}
}

impl ExtractHugr for &Hugr {
fn extract_hugr(self) -> Hugr {
self.clone()
}
}

impl ExtractHugr for &mut Hugr {
fn extract_hugr(self) -> Hugr {
self.clone()
}
}

impl<T: AsRef<Hugr>> HugrView for T {
/// An Iterator over the nodes in a Hugr(View)
type Nodes<'a> = MapInto<multiportgraph::Nodes<'a>, Node> where Self: 'a;
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::hugr::HugrError;
use crate::ops::handle::NodeHandle;
use crate::{Direction, Hugr, Node, Port};

use super::{check_tag, HierarchyView, HugrInternals, HugrView, RootTagged};
use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};

type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -175,6 +175,8 @@ where
}
}

impl<'g, Root: NodeHandle> ExtractHugr for DescendantsGraph<'g, Root> {}

impl<'g, Root> super::HugrInternals for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
Expand Down
21 changes: 20 additions & 1 deletion hugr-core/src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::hugr::{HugrError, HugrMut};
use crate::ops::handle::NodeHandle;
use crate::{Direction, Hugr, Node, Port};

use super::{check_tag, HierarchyView, HugrInternals, HugrView, RootTagged};
use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};

type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -205,6 +205,8 @@ where
}
}

impl<'g, Root: NodeHandle> ExtractHugr for SiblingGraph<'g, Root> {}

impl<'g, Root> HugrInternals for SiblingGraph<'g, Root>
where
Root: NodeHandle,
Expand Down Expand Up @@ -268,6 +270,8 @@ impl<'g, Root: NodeHandle> SiblingMut<'g, Root> {
}
}

impl<'g, Root: NodeHandle> ExtractHugr for SiblingMut<'g, Root> {}

impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> {
type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p;

Expand Down Expand Up @@ -484,4 +488,19 @@ mod test {
let nested_sib_mut = SiblingMut::<DataflowParentID>::try_new(&mut sib_mut, root);
assert!(nested_sib_mut.is_err());
}

#[rstest]
fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, _inner) = make_module_hgr()?;

let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?;
let extracted = region.extract_hugr();

let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?;

assert_eq!(region.node_count(), extracted.node_count());
assert_eq!(region.root_type(), extracted.root_type());

Ok(())
}
}