Skip to content

Commit

Permalink
feat: DataflowParent trait for getting inner signatures (#782)
Browse files Browse the repository at this point in the history
use trait to common up validation and `get_function_type`
Introduces new OpType level `OpParent` trait to optionally retrieve
inner function type.
  • Loading branch information
ss2165 authored Jan 8, 2024
1 parent 25d03aa commit 16dce1e
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 143 deletions.
29 changes: 6 additions & 23 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::ops::{self, DataflowBlock, ExitBlock, OpType};
use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
};
use crate::{hugr::views::HugrView, types::TypeRow};
use crate::{ops::handle::NodeHandle, types::Type};

use crate::Node;
use crate::{
Expand Down Expand Up @@ -150,13 +149,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
self.hugr_mut().add_node_with_parent(parent, op)
}?;

BlockBuilder::create(
self.hugr_mut(),
block_n,
tuple_sum_rows,
other_outputs,
inputs,
)
BlockBuilder::create(self.hugr_mut(), block_n)
}

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
Expand Down Expand Up @@ -248,19 +241,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
) -> Result<(), BuildError> {
Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
}

fn create(
base: B,
block_n: Node,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
inputs: TypeRow,
) -> Result<Self, BuildError> {
// The node outputs a TupleSum before the data outputs of the block node
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows);
let mut node_outputs = vec![tuple_sum_type];
node_outputs.extend_from_slice(&other_outputs);
let signature = FunctionType::new(inputs, TypeRow::from(node_outputs));
fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
let block_op = base.get_optype(block_n).as_dataflow_block().unwrap();
let signature = block_op.inner_signature();
let inp_ex = base
.as_ref()
.get_nodetype(block_n)
Expand Down Expand Up @@ -305,7 +288,7 @@ impl BlockBuilder<Hugr> {

let base = Hugr::new(NodeType::new(op, input_extensions));
let root = base.root();
Self::create(base, root, tuple_sum_rows, other_outputs, inputs)
Self::create(base, root)
}

/// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`).
Expand Down
1 change: 1 addition & 0 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
use crate::ops::custom::{ExternalOp, OpaqueOp};
use crate::ops::dataflow::DataflowParent;
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
use crate::ops::{LeafOp, OpType};

Expand Down
17 changes: 6 additions & 11 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE};
use crate::ops::handle::NodeHandle;
use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG};
use crate::ops::{OpName, OpParent, OpTag, OpTrait, OpType};

use crate::types::Type;
use crate::types::{EdgeKind, FunctionType, PolyFuncType};
use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, IncomingPort, Node, OutgoingPort, Port};

use itertools::Either;
Expand Down Expand Up @@ -331,16 +331,11 @@ pub trait HugrView: sealed::HugrInternals {
}
}

/// For function-like HUGRs (DFG, FuncDefn, FuncDecl), report the function
/// type. Otherwise return None.
fn get_function_type(&self) -> Option<PolyFuncType> {
/// For HUGRs with a [`DataflowParent`][crate::ops::DataflowParent] root operation, report the
/// signature of the inner dataflow sibling graph. Otherwise return None.
fn get_function_type(&self) -> Option<FunctionType> {
let op = self.get_nodetype(self.root());
match &op.op {
OpType::DFG(DFG { signature }) => Some(signature.clone().into()),
OpType::FuncDecl(FuncDecl { signature, .. })
| OpType::FuncDefn(FuncDefn { signature, .. }) => Some(signature.clone()),
_ => None,
}
op.op.inner_function_type()
}

/// Return a wrapper over the view that can be used in petgraph algorithms.
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ pub(super) mod test {

assert_eq!(
region.get_function_type(),
Some(FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into())
Some(FunctionType::new_endo(type_row![NAT, QB]))
);
let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?;
assert_eq!(
inner_region.get_function_type(),
Some(FunctionType::new(type_row![NAT], type_row![NAT]).into())
Some(FunctionType::new(type_row![NAT], type_row![NAT]))
);

Ok(())
Expand Down
6 changes: 1 addition & 5 deletions src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,7 @@ mod test {
fn flat_mut(mut simple_dfg_hugr: Hugr) {
simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap();
let root = simple_dfg_hugr.root();
let signature = simple_dfg_hugr
.get_function_type()
.unwrap()
.instantiate(&[], &PRELUDE_REGISTRY)
.unwrap();
let signature = simple_dfg_hugr.get_function_type().unwrap();

let sib_mut = SiblingMut::<CfgID>::try_new(&mut simple_dfg_hugr, root);
assert_eq!(
Expand Down
38 changes: 35 additions & 3 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ use smol_str::SmolStr;
use enum_dispatch::enum_dispatch;

pub use constant::Const;
pub use controlflow::{Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG};
pub use dataflow::{Call, CallIndirect, Input, LoadConstant, Output, DFG};
pub use controlflow::{BasicBlock, Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG};
pub use dataflow::{Call, CallIndirect, DataflowParent, Input, LoadConstant, Output, DFG};
pub use leaf::LeafOp;
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
pub use tag::OpTag;

#[enum_dispatch(OpTrait, OpName, ValidateOp)]
#[enum_dispatch(OpTrait, OpName, ValidateOp, OpParent)]
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
/// The concrete operation types for a node in the HUGR.
// TODO: Link the NodeHandles to the OpType.
Expand Down Expand Up @@ -368,6 +368,38 @@ pub trait OpTrait {
}
}

/// Properties of child graphs of ops, if the op has children.
#[enum_dispatch]
pub trait OpParent {
/// The inner function type of the operation, if it has a child dataflow
/// sibling graph.
fn inner_function_type(&self) -> Option<FunctionType> {
None
}
}

impl<T: DataflowParent> OpParent for T {
fn inner_function_type(&self) -> Option<FunctionType> {
Some(DataflowParent::inner_signature(self))
}
}

impl OpParent for Module {}
impl OpParent for AliasDecl {}
impl OpParent for AliasDefn {}
impl OpParent for Const {}
impl OpParent for Input {}
impl OpParent for Output {}
impl OpParent for Call {}
impl OpParent for CallIndirect {}
impl OpParent for LoadConstant {}
impl OpParent for LeafOp {}
impl OpParent for TailLoop {}
impl OpParent for CFG {}
impl OpParent for Conditional {}
impl OpParent for FuncDecl {}
impl OpParent for ExitBlock {}

#[enum_dispatch]
/// Methods for Ops to validate themselves and children
pub trait ValidateOp {
Expand Down
23 changes: 17 additions & 6 deletions src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};
use crate::{type_row, Direction};

use super::dataflow::DataflowOpTrait;
use super::dataflow::{DataflowOpTrait, DataflowParent};
use super::OpTag;
use super::{impl_op_name, OpName, OpTrait, StaticTag};

Expand Down Expand Up @@ -152,6 +152,16 @@ impl StaticTag for ExitBlock {
const TAG: OpTag = OpTag::BasicBlockExit;
}

impl DataflowParent for DataflowBlock {
fn inner_signature(&self) -> FunctionType {
// The node outputs a TupleSum before the data outputs of the block node
let tuple_sum_type = Type::new_tuple_sum(self.tuple_sum_rows.clone());
let mut node_outputs = vec![tuple_sum_type];
node_outputs.extend_from_slice(&self.other_outputs);
FunctionType::new(self.inputs.clone(), TypeRow::from(node_outputs))
}
}

impl OpTrait for DataflowBlock {
fn description(&self) -> &str {
"A CFG basic block node"
Expand Down Expand Up @@ -253,6 +263,12 @@ impl StaticTag for Case {
const TAG: OpTag = OpTag::Case;
}

impl DataflowParent for Case {
fn inner_signature(&self) -> FunctionType {
self.signature.clone()
}
}

impl OpTrait for Case {
fn description(&self) -> &str {
"A case node inside a conditional"
Expand All @@ -277,11 +293,6 @@ impl Case {
pub fn dataflow_output(&self) -> &TypeRow {
&self.signature.output
}

/// The signature of the dataflow sibling graph contained in the [`Case`]
pub fn inner_signature(&self) -> FunctionType {
self.signature.clone()
}
}

fn tuple_sum_first(tuple_sum_row: &TypeRow, rest: &TypeRow) -> TypeRow {
Expand Down
15 changes: 14 additions & 1 deletion src/ops/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ impl LoadConstant {
}
}

/// Operations that is the parent of a dataflow graph.
pub trait DataflowParent {
/// Signature of the inner dataflow graph.
fn inner_signature(&self) -> FunctionType;
}

/// A simply nested dataflow graph.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DFG {
Expand All @@ -285,6 +291,13 @@ pub struct DFG {
}

impl_op_name!(DFG);

impl DataflowParent for DFG {
fn inner_signature(&self) -> FunctionType {
self.signature.clone()
}
}

impl DataflowOpTrait for DFG {
const TAG: OpTag = OpTag::Dfg;

Expand All @@ -293,6 +306,6 @@ impl DataflowOpTrait for DFG {
}

fn signature(&self) -> FunctionType {
self.signature.clone()
self.inner_signature()
}
}
11 changes: 10 additions & 1 deletion src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
use smol_str::SmolStr;

use crate::types::{EdgeKind, PolyFuncType};
use crate::types::{EdgeKind, FunctionType, PolyFuncType};
use crate::types::{Type, TypeBound};

use super::dataflow::DataflowParent;
use super::StaticTag;
use super::{impl_op_name, OpTag, OpTrait};

Expand Down Expand Up @@ -43,6 +44,13 @@ impl_op_name!(FuncDefn);
impl StaticTag for FuncDefn {
const TAG: OpTag = OpTag::FuncDefn;
}

impl DataflowParent for FuncDefn {
fn inner_signature(&self) -> FunctionType {
self.signature.body().clone()
}
}

impl OpTrait for FuncDefn {
fn description(&self) -> &str {
"A function definition"
Expand Down Expand Up @@ -70,6 +78,7 @@ impl_op_name!(FuncDecl);
impl StaticTag for FuncDecl {
const TAG: OpTag = OpTag::Function;
}

impl OpTrait for FuncDecl {
fn description(&self) -> &str {
"External function declaration, linked at runtime"
Expand Down
Loading

0 comments on commit 16dce1e

Please sign in to comment.