From 1fac65bc54cc5bb42b27cb152b0f437421513038 Mon Sep 17 00:00:00 2001 From: doug-q <141026920+doug-q@users.noreply.github.com> Date: Mon, 16 Oct 2023 10:47:33 +0100 Subject: [PATCH] feat: Add accessors for node index and const values (#605) --- src/hugr.rs | 18 +++++++++++++++--- src/std_extensions/arithmetic/float_types.rs | 5 +++++ src/std_extensions/arithmetic/int_types.rs | 20 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/hugr.rs b/src/hugr.rs index ec5926fb9..ec44afa46 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -19,7 +19,7 @@ pub use ident::{IdentList, InvalidIdentifier}; pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError}; use portgraph::multiportgraph::MultiPortGraph; -use portgraph::{Hierarchy, NodeIndex, PortMut, UnmanagedDenseMap}; +use portgraph::{Hierarchy, PortMut, UnmanagedDenseMap}; use thiserror::Error; #[cfg(feature = "pyo3")] @@ -214,6 +214,12 @@ pub trait PortIndex { fn index(self) -> usize; } +/// A trait for getting the index of a node. +pub trait NodeIndex { + /// Returns the index of the node. + fn index(self) -> usize; +} + /// A port in the incoming direction. #[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Default, Debug)] pub struct IncomingPort { @@ -355,7 +361,7 @@ impl Hugr { source = ordered[source.index.index()]; } - let target: Node = NodeIndex::new(position).into(); + let target: Node = portgraph::NodeIndex::new(position).into(); if target != source { self.graph.swap_nodes(target.index, source.index); self.op_types.swap(target.index, source.index); @@ -363,7 +369,7 @@ impl Hugr { rekey(source, target); } } - self.root = NodeIndex::new(0); + self.root = portgraph::NodeIndex::new(0); // Finish by compacting the copy nodes. // The operation nodes will be left in place. @@ -494,6 +500,12 @@ impl TryFrom for OutgoingPort { } } +impl NodeIndex for Node { + fn index(self) -> usize { + self.index.into() + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] /// A DataFlow wire, defined by a Value-kind output port of a node // Stores node and offset to output port diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index a943788f3..5413b6183 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -42,6 +42,11 @@ impl ConstF64 { pub fn new(value: f64) -> Self { Self { value } } + + /// Returns the value of the constant + pub fn value(&self) -> f64 { + self.value + } } impl KnownTypeConst for ConstF64 { diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 207f1065c..87f88f2d0 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -103,6 +103,16 @@ impl ConstIntU { } Ok(Self { log_width, value }) } + + /// Returns the value of the constant + pub fn value(&self) -> u64 { + self.value + } + + /// Returns the number of bits of the constant + pub fn log_width(&self) -> u8 { + self.log_width + } } impl ConstIntS { @@ -123,6 +133,16 @@ impl ConstIntS { } Ok(Self { log_width, value }) } + + /// Returns the value of the constant + pub fn value(&self) -> i64 { + self.value + } + + /// Returns the number of bits of the constant + pub fn log_width(&self) -> u8 { + self.log_width + } } #[typetag::serde]