Skip to content

Commit

Permalink
feat: No polymorphic closures (#906)
Browse files Browse the repository at this point in the history
* Type::Function now stores only a FunctionType, not a PolyFuncType
* PolyFuncType remains, for OpDef's and FuncDefn/Decl's
* EdgeKind::Static now replaced by EdgeKind::Const (a type)
EdgeKind::Static (a PolyFuncType)
* Remove LeafOp::TypeApply, repurpose validation code onto Call
* Thus, progressively remove all `impl Substitution`s except for `struct
SubstValues`, which can become Substitution
* Update spec, introducing "Static" and "Dataflow" edge kinds as broader
classes of the other edge kinds, and polymorphic "type schemes" vs
monomorphic "Function types".
* Update serialization schema and add roundtrip test of a Noop operating
on a value of function type

fixes #904

This should enable resolving #788 and related capture/closure issues if
we forbid edges into a FuncDefn from outside (@doug-q)

BREAKING CHANGE: EdgeKind::{Static -> Const}, add new
EdgeKind::Function, Type contains only monomorphic functions, remove
TypeApply.
  • Loading branch information
acl-cqc authored Apr 9, 2024
1 parent 100e2cc commit b05dd6b
Show file tree
Hide file tree
Showing 20 changed files with 321 additions and 720 deletions.
21 changes: 10 additions & 11 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class SumType(RootModel):


class Variable(BaseModel):
"""A type variable identified by a de Bruijn index."""
"""A type variable identified by an index into the array of TypeParams."""

t: Literal["V"] = "V"
i: int
Expand All @@ -189,6 +189,8 @@ class FunctionType(BaseModel):
"""A graph encoded as a value. It contains a concrete signature and a set of
required resources."""

t: Literal["G"] = "G"

input: "TypeRow" # Value inputs of the function.
output: "TypeRow" # Value outputs of the function.
# The extension requirements which are added by the operation
Expand All @@ -209,15 +211,12 @@ class Config:


class PolyFuncType(BaseModel):
"""A graph encoded as a value. It contains a concrete signature and a set of
required resources."""

t: Literal["G"] = "G"
"""A polymorphic type scheme, i.e. of a FuncDecl, FuncDefn or OpDef.
(Nodes/operations in the Hugr are not polymorphic.)"""

# The declared type parameters, i.e., these must be instantiated with the same
# number of TypeArgs before the function can be called. Note that within the body,
# variable (DeBruijn) index 0 is element 0 of this array, i.e. the variables are
# bound from right to left.
# number of TypeArgs before the function can be called. This defines the indices
# used for variables within the body.
params: list[TypeParam]

# Template for the function. May contain variables up to length of `params`
Expand All @@ -231,8 +230,8 @@ class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"description": (
"A graph encoded as a value. It contains a concrete signature and "
"a set of required resources."
"A polymorphic type scheme, i.e. of a FuncDecl, FuncDefn or OpDef. "
"(Nodes/operations in the Hugr are not polymorphic.)"
)
}

Expand Down Expand Up @@ -279,7 +278,7 @@ class Type(RootModel):
"""A HUGR type."""

root: Annotated[
Qubit | Variable | USize | PolyFuncType | Array | SumType | Opaque,
Qubit | Variable | USize | FunctionType | Array | SumType | Opaque,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="t")

Expand Down
25 changes: 13 additions & 12 deletions hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ use crate::ops;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::types::type_param::{check_type_args, TypeArgError};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{
check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound, TypeName,
};
use crate::types::FunctionType;
use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName};

#[allow(dead_code)]
mod infer;
Expand Down Expand Up @@ -163,14 +162,16 @@ pub enum SignatureError {
/// A type variable that was used has not been declared
#[error("Type variable {idx} was not declared ({num_decls} in scope)")]
FreeTypeVar { idx: usize, num_decls: usize },
/// The type stored in a [LeafOp::TypeApply] is not what we compute from the
/// [ExtensionRegistry].
/// The result of the type application stored in a [Call]
/// is not what we get by applying the type-args to the polymorphic function
///
/// [LeafOp::TypeApply]: crate::ops::LeafOp::TypeApply
#[error("Incorrect result of type application - cached {cached} but expected {expected}")]
TypeApplyIncorrectCache {
cached: PolyFuncType,
expected: PolyFuncType,
/// [Call]: crate::ops::dataflow::Call
#[error(
"Incorrect result of type application in Call - cached {cached} but expected {expected}"
)]
CallIncorrectlyAppliesType {
cached: FunctionType,
expected: FunctionType,
},
}

Expand Down Expand Up @@ -418,7 +419,7 @@ impl ExtensionSet {

/// Adds a type var (which must have been declared as a [TypeParam::Extensions]) to this set
pub fn insert_type_var(&mut self, idx: usize) {
// Represent type vars as string representation of DeBruijn index.
// Represent type vars as string representation of variable index.
// This is not a legal IdentList or ExtensionId so should not conflict.
self.0
.insert(ExtensionId::new_unchecked(idx.to_string().as_str()));
Expand Down Expand Up @@ -491,7 +492,7 @@ impl ExtensionSet {
.try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions))
}

pub(crate) fn substitute(&self, t: &impl Substitution) -> Self {
pub(crate) fn substitute(&self, t: &Substitution) -> Self {
Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) {
None => vec![e.clone()],
Some(i) => match t.apply_var(i, &TypeParam::Extensions) {
Expand Down
9 changes: 3 additions & 6 deletions hugr/src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,9 @@ impl UnificationContext {
let sig = hugr.get_nodetype(tgt_node).op();
// Incoming ports with an edge that should mean equal extension reqs
for port in hugr.node_inputs(tgt_node).filter(|src_port| {
matches!(
sig.port_kind(*src_port),
Some(EdgeKind::Value(_))
| Some(EdgeKind::Static(_))
| Some(EdgeKind::ControlFlow)
)
let kind = sig.port_kind(*src_port);
kind.as_ref().is_some_and(EdgeKind::is_static)
|| matches!(kind, Some(EdgeKind::Value(_)) | Some(EdgeKind::ControlFlow))
}) {
let m_tgt = *self
.extensions
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl OpDef {
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
// for both type scheme and custom binary
if let SignatureFunc::TypeScheme(ts) = &self.signature_func {
ts.poly_func.validate(exts, &[])?;
ts.poly_func.validate(exts)?;
}
Ok(())
}
Expand Down
16 changes: 9 additions & 7 deletions hugr/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub enum NewEdgeKind {
/// The target port
tgt_pos: IncomingPort,
},
/// An [EdgeKind::Static] edge
/// An [EdgeKind::Const] or [EdgeKind::Function] edge
Static {
/// The source port
src_pos: OutgoingPort,
Expand Down Expand Up @@ -90,9 +90,10 @@ impl NewEdgeSpec {
NewEdgeKind::Value { src_pos, .. } => {
matches!(optype.port_kind(src_pos), Some(EdgeKind::Value(_)))
}
NewEdgeKind::Static { src_pos, .. } => {
matches!(optype.port_kind(src_pos), Some(EdgeKind::Static(_)))
}
NewEdgeKind::Static { src_pos, .. } => optype
.port_kind(src_pos)
.as_ref()
.is_some_and(EdgeKind::is_static),
NewEdgeKind::ControlFlow { src_pos } => {
matches!(optype.port_kind(src_pos), Some(EdgeKind::ControlFlow))
}
Expand All @@ -107,9 +108,10 @@ impl NewEdgeSpec {
NewEdgeKind::Value { tgt_pos, .. } => {
matches!(optype.port_kind(tgt_pos), Some(EdgeKind::Value(_)))
}
NewEdgeKind::Static { tgt_pos, .. } => {
matches!(optype.port_kind(tgt_pos), Some(EdgeKind::Static(_)))
}
NewEdgeKind::Static { tgt_pos, .. } => optype
.port_kind(tgt_pos)
.as_ref()
.is_some_and(EdgeKind::is_static),
NewEdgeKind::ControlFlow { .. } => matches!(
optype.port_kind(IncomingPort::from(0)),
Some(EdgeKind::ControlFlow)
Expand Down
11 changes: 11 additions & 0 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,17 @@ pub mod test {
Ok(())
}

#[test]
fn function_type() -> Result<(), Box<dyn std::error::Error>> {
let fn_ty = Type::new_function(FunctionType::new_endo(type_row![BOOL_T]));
let mut bldr = DFGBuilder::new(FunctionType::new_endo(vec![fn_ty.clone()]))?;
let op = bldr.add_dataflow_op(LeafOp::Noop { ty: fn_ty }, bldr.input_wires())?;
let h = bldr.finish_prelude_hugr_with_outputs(op.outputs())?;

check_hugr_roundtrip(&h);
Ok(())
}

#[test]
fn hierarchy_order() -> Result<(), Box<dyn std::error::Error>> {
let mut hugr = closed_dfg_root_hugr(FunctionType::new(vec![QB], vec![QB]));
Expand Down
98 changes: 37 additions & 61 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::ops::custom::{resolve_opaque_op, ExternalOp};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{FuncDefn, OpTag, OpTrait, OpType, ValidateOp};
use crate::types::type_param::TypeParam;
use crate::types::{EdgeKind, Type};
use crate::types::EdgeKind;
use crate::{Direction, Hugr, Node, Port};

use super::views::{HierarchyView, HugrView, SiblingGraph};
Expand Down Expand Up @@ -219,17 +219,8 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
return Ok(());
}

match &port_kind {
EdgeKind::Value(ty) => ty
.validate(self.extension_registry, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })?,
// Static edges must *not* refer to type variables declared by enclosing FuncDefns
// as these are only types at runtime.
EdgeKind::Static(ty) => ty
.validate(self.extension_registry, &[])
.map_err(|cause| ValidationError::SignatureError { node, cause })?,
_ => (),
}
self.validate_port_kind(&port_kind, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;

let mut link_cnt = 0;
for (_, link) in links {
Expand Down Expand Up @@ -271,6 +262,21 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
Ok(())
}

fn validate_port_kind(
&self,
port_kind: &EdgeKind,
var_decls: &[TypeParam],
) -> Result<(), SignatureError> {
match &port_kind {
EdgeKind::Value(ty) => ty.validate(self.extension_registry, var_decls),
// Static edges must *not* refer to type variables declared by enclosing FuncDefns
// as these are only types at runtime.
EdgeKind::Const(ty) => ty.validate(self.extension_registry, &[]),
EdgeKind::Function(pf) => pf.validate(self.extension_registry),
_ => Ok(()),
}
}

/// Check operation-specific constraints.
///
/// These are flags defined for each operation type as an [`OpValidityFlags`] object.
Expand Down Expand Up @@ -409,45 +415,26 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
from_optype: &OpType,
to: Node,
to_offset: Port,
) -> Result<(), ValidationError> {
) -> Result<(), InterGraphEdgeError> {
let from_parent = self
.hugr
.get_parent(from)
.expect("Root nodes cannot have ports");
let to_parent = self.hugr.get_parent(to);
let local = Some(from_parent) == to_parent;

let is_static = match from_optype.port_kind(from_offset).unwrap() {
EdgeKind::Static(typ) => {
if !(OpTag::Const.is_superset(from_optype.tag())
|| OpTag::Function.is_superset(from_optype.tag()))
{
return Err(InterGraphEdgeError::InvalidConstSrc {
from,
from_offset,
typ,
}
.into());
};
true
}
ty => {
if !local && !matches!(&ty, EdgeKind::Value(t) if t.copyable()) {
return Err(InterGraphEdgeError::NonCopyableData {
from,
from_offset,
to,
to_offset,
ty,
}
.into());
}
false
}
};
if local {
return Ok(());
let edge_kind = from_optype.port_kind(from_offset).unwrap();
if Some(from_parent) == to_parent {
return Ok(()); // Local edge
}
let is_static = edge_kind.is_static();
if !is_static && !matches!(&edge_kind, EdgeKind::Value(t) if t.copyable()) {
return Err(InterGraphEdgeError::NonCopyableData {
from,
from_offset,
to,
to_offset,
ty: edge_kind,
});
};

// To detect either external or dominator edges, we traverse the ancestors
// of the target until we find either `from_parent` (in the external
Expand Down Expand Up @@ -489,8 +476,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
to,
to_offset,
ancestor_parent_op: ancestor_parent_op.clone(),
}
.into());
});
}

// Check domination
Expand All @@ -513,8 +499,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
to_offset,
from_parent,
ancestor,
}
.into());
});
}

return Ok(());
Expand All @@ -526,8 +511,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
from_offset,
to,
to_offset,
}
.into())
})
}

/// Validates that TypeArgs are valid wrt the [ExtensionRegistry] and that nodes
Expand Down Expand Up @@ -573,8 +557,8 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
}
}
}
OpType::LeafOp(crate::ops::LeafOp::TypeApply { ta }) => {
ta.validate(self.extension_registry)
OpType::Call(c) => {
c.validate(self.extension_registry)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
_ => (),
Expand Down Expand Up @@ -768,14 +752,6 @@ pub enum InterGraphEdgeError {
from_parent: Node,
ancestor: Node,
},
#[error(
"Const edge comes from an invalid node type: {from:?} ({from_offset:?}). Edge type: {typ}"
)]
InvalidConstSrc {
from: Node,
from_offset: Port,
typ: Type,
},
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit b05dd6b

Please sign in to comment.