diff --git a/quantinuum-hugr/src/ops/constant.rs b/quantinuum-hugr/src/ops/constant.rs index 14b5ef222..c8079e31d 100644 --- a/quantinuum-hugr/src/ops/constant.rs +++ b/quantinuum-hugr/src/ops/constant.rs @@ -2,10 +2,10 @@ mod custom; -use super::{OpName, OpTrait, StaticTag}; +use super::{OpName, OpParent, OpTrait, StaticTag}; use super::{OpTag, OpType}; use crate::extension::ExtensionSet; -use crate::types::{CustomType, EdgeKind, SumType, SumTypeError, Type}; +use crate::types::{CustomType, EdgeKind, PolyFuncType, SumType, SumTypeError, Type}; use crate::{Hugr, HugrView}; use itertools::Itertools; @@ -100,7 +100,7 @@ impl Const { Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::const_type).collect_vec()), Self::Sum { sum_type, .. } => sum_type.clone().into(), Self::Function { hugr } => { - let func_type = hugr.get_function_type().unwrap_or_else(|| { + let func_type = get_const_function_type(hugr).unwrap_or_else(|| { panic!( "{}", ConstTypeError::FunctionTypeMissing { @@ -143,7 +143,7 @@ impl Const { /// Returns an error if the Hugr root node does not define a function. pub fn function(hugr: impl Into) -> Result { let hugr = hugr.into(); - if hugr.get_function_type().is_none() { + if get_const_function_type(&hugr).is_none() { Err(ConstTypeError::FunctionTypeMissing { hugr_root_type: hugr.get_optype(hugr.root()).clone(), })?; @@ -210,7 +210,7 @@ impl OpName for Const { match self { Self::Extension { c: e } => format!("const:custom:{}", e.0.name()), Self::Function { hugr: h } => { - let Some(t) = h.get_function_type() else { + let Some(t) = get_const_function_type(h) else { panic!("HUGR root node isn't a valid function parent."); }; format!("const:function:[{}]", t) @@ -254,6 +254,27 @@ impl OpTrait for Const { } } +/// Returns the function type defined by the HUGR. In contrast to +/// [`Hugr::get_function_type`], this function also returns function types for +/// `FuncDecl` and `FuncDefn` operations. +/// +/// For HUGRs with a [`DataflowParent`][crate::ops::DataflowParent] root +/// operation, report the signature of the inner dataflow sibling graph. +/// +/// For HUGRS with a [`FuncDecl`][crate::ops::FuncDecl] or +/// [`FuncDefn`][crate::ops::FuncDefn] root operation, report the signature of +/// the function. +/// +/// Otherwise, returns `None`. +fn get_const_function_type(hugr: &impl HugrView) -> Option { + let op = hugr.get_optype(hugr.root()); + match op { + OpType::FuncDecl(decl) => Some(decl.signature.clone()), + OpType::FuncDefn(defn) => Some(defn.signature.clone()), + _ => op.inner_function_type().map(PolyFuncType::from), + } +} + // [KnownTypeConst] is guaranteed to be the right type, so can be constructed // without initial type check. impl From for Const