Skip to content

Commit

Permalink
feat!: Flatten LeafOp (#922)
Browse files Browse the repository at this point in the history
Moves the variants of `LeafOp` into `OpType`. Closes #817.

It seems `Lift` was missing from the schema, and in its place there was
a `TypeApplication` struct. I replaced it with lift.

As an aside, the `CustomOp -> ExternalOp -> {ExtensionOp, OpaqueOp}`
chain seems unnecessary. We should be able to merge `CustomOp` and
`ExternalOp` into one, and get rid of multiple `as_ref`s in the code.
This doesn't modify the schema, so I'll push it as a separate PR.

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
aborgna-q and ss2165 authored Apr 12, 2024
1 parent 75e75e8 commit 3598913
Show file tree
Hide file tree
Showing 29 changed files with 460 additions and 410 deletions.
73 changes: 20 additions & 53 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,6 @@ class LoadConstant(DataflowOp):
datatype: Type


class LeafOpBase(DataflowOp):
"""Simple operation that has only value inputs+outputs and (potentially) StateOrder
edges."""

op: Literal["LeafOp"] = "LeafOp"


class DFG(DataflowOp):
"""A simply nested dataflow graph."""

Expand Down Expand Up @@ -393,16 +386,11 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None:
ControlFlowOp = Conditional | TailLoop | CFG


# -----------------------------------------
# --------------- LeafOp ------------------
# -----------------------------------------


class CustomOp(LeafOpBase):
class CustomOp(DataflowOp):
"""A user-defined operation that can be downcasted by the extensions that define
it."""

lop: Literal["CustomOp"] = "CustomOp"
op: Literal["CustomOp"] = "CustomOp"
extension: ExtensionId
op_name: str
signature: tys.FunctionType = Field(default_factory=tys.FunctionType.empty)
Expand All @@ -425,10 +413,10 @@ class Config:
}


class Noop(LeafOpBase):
class Noop(DataflowOp):
"""A no-op operation."""

lop: Literal["Noop"] = "Noop"
op: Literal["Noop"] = "Noop"
ty: Type

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -438,10 +426,10 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.ty = in_types[0]


class MakeTuple(LeafOpBase):
class MakeTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""

lop: Literal["MakeTuple"] = "MakeTuple"
op: Literal["MakeTuple"] = "MakeTuple"
tys: TypeRow = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -451,56 +439,30 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(in_types)


class UnpackTuple(LeafOpBase):
class UnpackTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""

lop: Literal["UnpackTuple"] = "UnpackTuple"
op: Literal["UnpackTuple"] = "UnpackTuple"
tys: TypeRow = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(out_types)


class Tag(LeafOpBase):
class Tag(DataflowOp):
"""An operation that creates a tagged sum value from one of its variants."""

lop: Literal["Tag"] = "Tag"
op: Literal["Tag"] = "Tag"
tag: int # The variant to create.
variants: TypeRow # The variants of the sum type.


class TypeApply(LeafOpBase):
class Lift(DataflowOp):
"""Fixes some TypeParams of a polymorphic type by providing TypeArgs."""

lop: Literal["TypeApply"] = "TypeApply"
ta: "TypeApplication"


class TypeApplication(BaseModel):
"""Records details of an application of a PolyFuncType to some TypeArgs and the
result (a less-, but still potentially-, polymorphic type).
"""

input: PolyFuncType
args: list[tys.TypeTypeArg]
output: PolyFuncType

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"description": (
"Records details of an application of a PolyFuncType to some TypeArgs "
"and the result (a less-, but still potentially-, polymorphic type)."
)
}


class LeafOp(RootModel):
"""A constant operation."""

root: CustomOp | Noop | MakeTuple | UnpackTuple | Tag | TypeApply = Field(
discriminator="lop"
)
op: Literal["Lift"] = "Lift"
type_row: TypeRow
new_extension: ExtensionId


class OpType(RootModel):
Expand All @@ -522,7 +484,12 @@ class OpType(RootModel):
| Call
| CallIndirect
| LoadConstant
| LeafOp
| CustomOp
| Noop
| MakeTuple
| UnpackTuple
| Tag
| Lift
| DFG
) = Field(discriminator="op")

Expand Down
29 changes: 14 additions & 15 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
views::SiblingSubgraph,
HugrMut,
},
ops::{Const, LeafOp},
ops::{Const, OpType},
type_row,
types::FunctionType,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
Expand Down Expand Up @@ -44,29 +44,28 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> {
.collect()
}
/// For a given op and consts, attempt to evaluate the op.
pub fn fold_leaf_op(op: &LeafOp, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
match op {
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]),
LeafOp::MakeTuple { .. } => {
OpType::Noop { .. } => out_row([consts.first()?.1.clone()]),
OpType::MakeTuple { .. } => {
out_row([Const::tuple(sorted_consts(consts).into_iter().cloned())])
}
LeafOp::UnpackTuple { .. } => {
OpType::UnpackTuple { .. } => {
let c = &consts.first()?.1;
let Const::Tuple { vs } = c else {
panic!("This op always takes a Tuple input.");
};
out_row(vs.iter().cloned())
}

LeafOp::Tag { tag, variants } => out_row([Const::sum(
*tag,
OpType::Tag(t) => out_row([Const::sum(
t.tag,
consts.iter().map(|(_, konst)| konst.clone()),
SumType::new(variants.clone()),
SumType::new(t.variants.clone()),
)
.unwrap()]),
LeafOp::CustomOp(_) => {
let ext_op = op.as_extension_op()?;

OpType::CustomOp(op) => {
let ext_op = op.as_ref().as_extension_op()?;
ext_op.constant_fold(consts)
}
_ => None,
Expand Down Expand Up @@ -132,7 +131,7 @@ fn fold_op(
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveLoadConstant>)> {
// only support leaf folding for now.
let neighbour_op = hugr.get_optype(op_node).as_leaf_op()?;
let neighbour_op = hugr.get_optype(op_node);
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| {
Expand Down Expand Up @@ -214,7 +213,7 @@ mod test {
use super::*;
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::OpType;
use crate::ops::{OpType, UnpackTuple};
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
Expand Down Expand Up @@ -242,7 +241,7 @@ mod test {
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_leaf_op(add_op.as_leaf_op().unwrap(), &consts).unwrap();
let out = fold_leaf_op(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}
Expand All @@ -264,7 +263,7 @@ mod test {

let unpack = build
.add_dataflow_op(
LeafOp::UnpackTuple {
UnpackTuple {
tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE],
},
[tup],
Expand Down
12 changes: 6 additions & 6 deletions hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::views::HugrView;
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTag, OpTrait, OpType};
use crate::ops::{self, MakeTuple, OpTag, OpTrait, OpType, Tag};
use crate::utils::collect_array;
use crate::{IncomingPort, Node, OutgoingPort};

Expand Down Expand Up @@ -471,25 +471,25 @@ pub trait Dataflow: Container {
}
}

/// Add a [`LeafOp::MakeTuple`] node and wire in the `values` Wires,
/// Add a [`MakeTuple`] node and wire in the `values` Wires,
/// returning the Wire corresponding to the tuple.
///
/// # Errors
///
/// This function will return an error if there is an error adding the
/// [`LeafOp::MakeTuple`] node.
/// [`MakeTuple`] node.
fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
let values = values.into_iter().collect_vec();
let types: Result<Vec<Type>, _> = values
.iter()
.map(|&wire| self.get_wire_type(wire))
.collect();
let types = types?.into();
let make_op = self.add_dataflow_op(LeafOp::MakeTuple { tys: types }, values)?;
let make_op = self.add_dataflow_op(MakeTuple { tys: types }, values)?;
Ok(make_op.out_wire(0))
}

/// Add a [`LeafOp::Tag`] node and wire in the `value` Wire,
/// Add a [`Tag`] node and wire in the `value` Wire,
/// to make a value with Sum type, with `tag` and possible types described
/// by `variants`.
/// Returns the Wire corresponding to the Sum value.
Expand All @@ -505,7 +505,7 @@ pub trait Dataflow: Container {
values: impl IntoIterator<Item = Wire>,
) -> Result<Wire, BuildError> {
let make_op = self.add_dataflow_op(
LeafOp::Tag {
Tag {
tag,
variants: variants.into_iter().map(Into::into).collect_vec(),
},
Expand Down
19 changes: 8 additions & 11 deletions hugr/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ mod test {
Dataflow, DataflowSubContainer, Wire,
},
extension::prelude::BOOL_T,
ops::{custom::OpaqueOp, LeafOp},
ops::{custom::OpaqueOp, CustomOp},
type_row,
types::FunctionType,
};
Expand Down Expand Up @@ -278,16 +278,13 @@ mod test {

#[test]
fn with_nonlinear_and_outputs() {
let my_custom_op = LeafOp::CustomOp(
crate::ops::custom::ExternalOp::Opaque(OpaqueOp::new(
"MissingRsrc".try_into().unwrap(),
"MyOp",
"unknown op".to_string(),
vec![],
FunctionType::new(vec![QB, NAT], vec![QB]),
))
.into(),
);
let my_custom_op = CustomOp::new(crate::ops::custom::ExternalOp::Opaque(OpaqueOp::new(
"MissingRsrc".try_into().unwrap(),
"MyOp",
"unknown op".to_string(),
vec![],
FunctionType::new(vec![QB, NAT], vec![QB]),
)));
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
|mut f_build| {
Expand Down
16 changes: 8 additions & 8 deletions hugr/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ pub(crate) mod test {
use crate::extension::prelude::BOOL_T;
use crate::extension::{ExtensionId, EMPTY_REG};
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::{handle::NodeHandle, LeafOp, OpTag};
use crate::ops::{handle::NodeHandle, Lift, Noop, OpTag};

use crate::std_extensions::logic::test::and_op;
use crate::types::Type;
Expand Down Expand Up @@ -347,13 +347,13 @@ pub(crate) mod test {
)?;

let [i1] = f_build.input_wires_arr();
let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: BIT }, [i1])?;
let noop = f_build.add_dataflow_op(Noop { ty: BIT }, [i1])?;
let i1 = noop.out_wire(0);

let mut nested =
f_build.dfg_builder(FunctionType::new(type_row![], type_row![BIT]), None, [])?;

let id = nested.add_dataflow_op(LeafOp::Noop { ty: BIT }, [i1])?;
let id = nested.add_dataflow_op(Noop { ty: BIT }, [i1])?;

let nested = nested.finish_with_outputs([id.out_wire(0)])?;

Expand All @@ -371,13 +371,13 @@ pub(crate) mod test {
)?;

let [i1] = f_build.input_wires_arr();
let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1])?;
let noop = f_build.add_dataflow_op(Noop { ty: QB }, [i1])?;
let i1 = noop.out_wire(0);

let mut nested =
f_build.dfg_builder(FunctionType::new(type_row![], type_row![QB]), None, [])?;

let id_res = nested.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1]);
let id_res = nested.add_dataflow_op(Noop { ty: QB }, [i1]);

// The error would anyway be caught in validation when we finish the Hugr,
// but the builder catches it earlier
Expand Down Expand Up @@ -457,7 +457,7 @@ pub(crate) mod test {
let [w] = add_ab.input_wires_arr();

let lift_a = add_ab.add_dataflow_op(
LeafOp::Lift {
Lift {
type_row: type_row![BIT],
new_extension: xa.clone(),
},
Expand All @@ -467,7 +467,7 @@ pub(crate) mod test {

let lift_b = add_ab.add_dataflow_node(
NodeType::new(
LeafOp::Lift {
Lift {
type_row: type_row![BIT],
new_extension: xb,
},
Expand All @@ -486,7 +486,7 @@ pub(crate) mod test {
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_node(
NodeType::new(
LeafOp::Lift {
Lift {
type_row: type_row![BIT],
new_extension: xc,
},
Expand Down
4 changes: 2 additions & 2 deletions hugr/src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod test {
let _fdef = {
let [b1] = fbuild
.add_dataflow_op(
ops::LeafOp::Lift {
ops::Lift {
type_row: type_row![BIT],
new_extension: PRELUDE_ID,
},
Expand All @@ -147,7 +147,7 @@ mod test {
let const_val = Const::true_val();
let const_wire = loop_b.add_load_const(Const::true_val());
let lift_node = loop_b.add_dataflow_op(
ops::LeafOp::Lift {
ops::Lift {
type_row: vec![const_val.const_type().clone()].into(),
new_extension: PRELUDE_ID,
},
Expand Down
Loading

0 comments on commit 3598913

Please sign in to comment.