Skip to content

Commit

Permalink
Use ad-hoc get_function_type
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Mar 14, 2024
1 parent b0c1f7b commit 76f18e5
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions quantinuum-hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
mod custom;

use super::{OpName, OpTrait, StaticTag};
use super::{OpName, OpParent, OpTrait, StaticTag};
use super::{OpTag, OpType};
use crate::extension::ExtensionSet;
use crate::types::{CustomType, EdgeKind, SumType, SumTypeError, Type};
use crate::types::{CustomType, EdgeKind, PolyFuncType, SumType, SumTypeError, Type};
use crate::{Hugr, HugrView};

use itertools::Itertools;
Expand Down Expand Up @@ -100,7 +100,7 @@ impl Const {
Self::Tuple { vs } => Type::new_tuple(vs.iter().map(Self::const_type).collect_vec()),
Self::Sum { sum_type, .. } => sum_type.clone().into(),
Self::Function { hugr } => {
let func_type = hugr.get_function_type().unwrap_or_else(|| {
let func_type = get_const_function_type(hugr).unwrap_or_else(|| {
panic!(
"{}",
ConstTypeError::FunctionTypeMissing {
Expand Down Expand Up @@ -143,7 +143,7 @@ impl Const {
/// Returns an error if the Hugr root node does not define a function.
pub fn function(hugr: impl Into<Hugr>) -> Result<Self, ConstTypeError> {
let hugr = hugr.into();
if hugr.get_function_type().is_none() {
if get_const_function_type(&hugr).is_none() {
Err(ConstTypeError::FunctionTypeMissing {
hugr_root_type: hugr.get_optype(hugr.root()).clone(),
})?;
Expand Down Expand Up @@ -210,7 +210,7 @@ impl OpName for Const {
match self {
Self::Extension { c: e } => format!("const:custom:{}", e.0.name()),
Self::Function { hugr: h } => {
let Some(t) = h.get_function_type() else {
let Some(t) = get_const_function_type(h) else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
Expand Down Expand Up @@ -254,6 +254,27 @@ impl OpTrait for Const {
}
}

/// Returns the function type defined by the HUGR. In contrast to
/// [`Hugr::get_function_type`], this function also returns function types for
/// `FuncDecl` and `FuncDefn` operations.
///
/// For HUGRs with a [`DataflowParent`][crate::ops::DataflowParent] root
/// operation, report the signature of the inner dataflow sibling graph.
///
/// For HUGRS with a [`FuncDecl`][crate::ops::FuncDecl] or
/// [`FuncDefn`][crate::ops::FuncDefn] root operation, report the signature of
/// the function.
///
/// Otherwise, returns `None`.
fn get_const_function_type(hugr: &impl HugrView) -> Option<PolyFuncType> {
let op = hugr.get_optype(hugr.root());
match op {
OpType::FuncDecl(decl) => Some(decl.signature.clone()),
OpType::FuncDefn(defn) => Some(defn.signature.clone()),
_ => op.inner_function_type().map(PolyFuncType::from),
}
}

// [KnownTypeConst] is guaranteed to be the right type, so can be constructed
// without initial type check.
impl<T> From<T> for Const
Expand Down

0 comments on commit 76f18e5

Please sign in to comment.