Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Flatten LeafOp #922

Merged
merged 9 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 20 additions & 53 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,6 @@ class LoadConstant(DataflowOp):
datatype: Type


class LeafOpBase(DataflowOp):
"""Simple operation that has only value inputs+outputs and (potentially) StateOrder
edges."""

op: Literal["LeafOp"] = "LeafOp"


class DFG(DataflowOp):
"""A simply nested dataflow graph."""

Expand Down Expand Up @@ -393,16 +386,11 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None:
ControlFlowOp = Conditional | TailLoop | CFG


# -----------------------------------------
# --------------- LeafOp ------------------
# -----------------------------------------


class CustomOp(LeafOpBase):
class CustomOp(DataflowOp):
"""A user-defined operation that can be downcasted by the extensions that define
it."""

lop: Literal["CustomOp"] = "CustomOp"
op: Literal["CustomOp"] = "CustomOp"
extension: ExtensionId
op_name: str
signature: tys.FunctionType = Field(default_factory=tys.FunctionType.empty)
Expand All @@ -425,10 +413,10 @@ class Config:
}


class Noop(LeafOpBase):
class Noop(DataflowOp):
"""A no-op operation."""

lop: Literal["Noop"] = "Noop"
op: Literal["Noop"] = "Noop"
ty: Type

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -438,10 +426,10 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.ty = in_types[0]


class MakeTuple(LeafOpBase):
class MakeTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""

lop: Literal["MakeTuple"] = "MakeTuple"
op: Literal["MakeTuple"] = "MakeTuple"
tys: TypeRow = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -451,56 +439,30 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(in_types)


class UnpackTuple(LeafOpBase):
class UnpackTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""

lop: Literal["UnpackTuple"] = "UnpackTuple"
op: Literal["UnpackTuple"] = "UnpackTuple"
tys: TypeRow = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(out_types)


class Tag(LeafOpBase):
class Tag(DataflowOp):
"""An operation that creates a tagged sum value from one of its variants."""

lop: Literal["Tag"] = "Tag"
op: Literal["Tag"] = "Tag"
tag: int # The variant to create.
variants: TypeRow # The variants of the sum type.


class TypeApply(LeafOpBase):
class Lift(DataflowOp):
"""Fixes some TypeParams of a polymorphic type by providing TypeArgs."""

lop: Literal["TypeApply"] = "TypeApply"
ta: "TypeApplication"


class TypeApplication(BaseModel):
"""Records details of an application of a PolyFuncType to some TypeArgs and the
result (a less-, but still potentially-, polymorphic type).
"""

input: PolyFuncType
args: list[tys.TypeTypeArg]
output: PolyFuncType

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"description": (
"Records details of an application of a PolyFuncType to some TypeArgs "
"and the result (a less-, but still potentially-, polymorphic type)."
)
}


class LeafOp(RootModel):
"""A constant operation."""

root: CustomOp | Noop | MakeTuple | UnpackTuple | Tag | TypeApply = Field(
discriminator="lop"
)
op: Literal["Lift"] = "Lift"
type_row: TypeRow
new_extension: ExtensionId


class OpType(RootModel):
Expand All @@ -522,7 +484,12 @@ class OpType(RootModel):
| Call
| CallIndirect
| LoadConstant
| LeafOp
| CustomOp
| Noop
| MakeTuple
| UnpackTuple
| Tag
| Lift
| DFG
) = Field(discriminator="op")

Expand Down
29 changes: 14 additions & 15 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
views::SiblingSubgraph,
HugrMut,
},
ops::{Const, LeafOp},
ops::{Const, OpType},
type_row,
types::FunctionType,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
Expand Down Expand Up @@ -44,29 +44,28 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> {
.collect()
}
/// For a given op and consts, attempt to evaluate the op.
pub fn fold_leaf_op(op: &LeafOp, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
match op {
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]),
LeafOp::MakeTuple { .. } => {
OpType::Noop { .. } => out_row([consts.first()?.1.clone()]),
OpType::MakeTuple { .. } => {
out_row([Const::tuple(sorted_consts(consts).into_iter().cloned())])
}
LeafOp::UnpackTuple { .. } => {
OpType::UnpackTuple { .. } => {
let c = &consts.first()?.1;
let Const::Tuple { vs } = c else {
panic!("This op always takes a Tuple input.");
};
out_row(vs.iter().cloned())
}

LeafOp::Tag { tag, variants } => out_row([Const::sum(
*tag,
OpType::Tag(t) => out_row([Const::sum(
t.tag,
consts.iter().map(|(_, konst)| konst.clone()),
SumType::new(variants.clone()),
SumType::new(t.variants.clone()),
)
.unwrap()]),
LeafOp::CustomOp(_) => {
let ext_op = op.as_extension_op()?;

OpType::CustomOp(op) => {
let ext_op = op.as_ref().as_extension_op()?;
ext_op.constant_fold(consts)
}
_ => None,
Expand Down Expand Up @@ -132,7 +131,7 @@ fn fold_op(
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveLoadConstant>)> {
// only support leaf folding for now.
let neighbour_op = hugr.get_optype(op_node).as_leaf_op()?;
let neighbour_op = hugr.get_optype(op_node);
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| {
Expand Down Expand Up @@ -214,7 +213,7 @@ mod test {
use super::*;
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::OpType;
use crate::ops::{OpType, UnpackTuple};
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
Expand Down Expand Up @@ -242,7 +241,7 @@ mod test {
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_leaf_op(add_op.as_leaf_op().unwrap(), &consts).unwrap();
let out = fold_leaf_op(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}
Expand All @@ -264,7 +263,7 @@ mod test {

let unpack = build
.add_dataflow_op(
LeafOp::UnpackTuple {
UnpackTuple {
tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE],
},
[tup],
Expand Down
12 changes: 6 additions & 6 deletions hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::views::HugrView;
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTag, OpTrait, OpType};
use crate::ops::{self, MakeTuple, OpTag, OpTrait, OpType, Tag};
use crate::utils::collect_array;
use crate::{IncomingPort, Node, OutgoingPort};

Expand Down Expand Up @@ -471,25 +471,25 @@ pub trait Dataflow: Container {
}
}

/// Add a [`LeafOp::MakeTuple`] node and wire in the `values` Wires,
/// Add a [`MakeTuple`] node and wire in the `values` Wires,
/// returning the Wire corresponding to the tuple.
///
/// # Errors
///
/// This function will return an error if there is an error adding the
/// [`LeafOp::MakeTuple`] node.
/// [`MakeTuple`] node.
fn make_tuple(&mut self, values: impl IntoIterator<Item = Wire>) -> Result<Wire, BuildError> {
let values = values.into_iter().collect_vec();
let types: Result<Vec<Type>, _> = values
.iter()
.map(|&wire| self.get_wire_type(wire))
.collect();
let types = types?.into();
let make_op = self.add_dataflow_op(LeafOp::MakeTuple { tys: types }, values)?;
let make_op = self.add_dataflow_op(MakeTuple { tys: types }, values)?;
Ok(make_op.out_wire(0))
}

/// Add a [`LeafOp::Tag`] node and wire in the `value` Wire,
/// Add a [`Tag`] node and wire in the `value` Wire,
/// to make a value with Sum type, with `tag` and possible types described
/// by `variants`.
/// Returns the Wire corresponding to the Sum value.
Expand All @@ -505,7 +505,7 @@ pub trait Dataflow: Container {
values: impl IntoIterator<Item = Wire>,
) -> Result<Wire, BuildError> {
let make_op = self.add_dataflow_op(
LeafOp::Tag {
Tag {
tag,
variants: variants.into_iter().map(Into::into).collect_vec(),
},
Expand Down
19 changes: 8 additions & 11 deletions hugr/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ mod test {
Dataflow, DataflowSubContainer, Wire,
},
extension::prelude::BOOL_T,
ops::{custom::OpaqueOp, LeafOp},
ops::{custom::OpaqueOp, CustomOp},
type_row,
types::FunctionType,
};
Expand Down Expand Up @@ -278,16 +278,13 @@ mod test {

#[test]
fn with_nonlinear_and_outputs() {
let my_custom_op = LeafOp::CustomOp(
crate::ops::custom::ExternalOp::Opaque(OpaqueOp::new(
"MissingRsrc".try_into().unwrap(),
"MyOp",
"unknown op".to_string(),
vec![],
FunctionType::new(vec![QB, NAT], vec![QB]),
))
.into(),
);
let my_custom_op = CustomOp::new(crate::ops::custom::ExternalOp::Opaque(OpaqueOp::new(
"MissingRsrc".try_into().unwrap(),
"MyOp",
"unknown op".to_string(),
vec![],
FunctionType::new(vec![QB, NAT], vec![QB]),
)));
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
|mut f_build| {
Expand Down
16 changes: 8 additions & 8 deletions hugr/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ pub(crate) mod test {
use crate::extension::prelude::BOOL_T;
use crate::extension::{ExtensionId, EMPTY_REG};
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::{handle::NodeHandle, LeafOp, OpTag};
use crate::ops::{handle::NodeHandle, Lift, Noop, OpTag};

use crate::std_extensions::logic::test::and_op;
use crate::types::Type;
Expand Down Expand Up @@ -347,13 +347,13 @@ pub(crate) mod test {
)?;

let [i1] = f_build.input_wires_arr();
let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: BIT }, [i1])?;
let noop = f_build.add_dataflow_op(Noop { ty: BIT }, [i1])?;
let i1 = noop.out_wire(0);

let mut nested =
f_build.dfg_builder(FunctionType::new(type_row![], type_row![BIT]), None, [])?;

let id = nested.add_dataflow_op(LeafOp::Noop { ty: BIT }, [i1])?;
let id = nested.add_dataflow_op(Noop { ty: BIT }, [i1])?;

let nested = nested.finish_with_outputs([id.out_wire(0)])?;

Expand All @@ -371,13 +371,13 @@ pub(crate) mod test {
)?;

let [i1] = f_build.input_wires_arr();
let noop = f_build.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1])?;
let noop = f_build.add_dataflow_op(Noop { ty: QB }, [i1])?;
let i1 = noop.out_wire(0);

let mut nested =
f_build.dfg_builder(FunctionType::new(type_row![], type_row![QB]), None, [])?;

let id_res = nested.add_dataflow_op(LeafOp::Noop { ty: QB }, [i1]);
let id_res = nested.add_dataflow_op(Noop { ty: QB }, [i1]);

// The error would anyway be caught in validation when we finish the Hugr,
// but the builder catches it earlier
Expand Down Expand Up @@ -457,7 +457,7 @@ pub(crate) mod test {
let [w] = add_ab.input_wires_arr();

let lift_a = add_ab.add_dataflow_op(
LeafOp::Lift {
Lift {
type_row: type_row![BIT],
new_extension: xa.clone(),
},
Expand All @@ -467,7 +467,7 @@ pub(crate) mod test {

let lift_b = add_ab.add_dataflow_node(
NodeType::new(
LeafOp::Lift {
Lift {
type_row: type_row![BIT],
new_extension: xb,
},
Expand All @@ -486,7 +486,7 @@ pub(crate) mod test {
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_node(
NodeType::new(
LeafOp::Lift {
Lift {
type_row: type_row![BIT],
new_extension: xc,
},
Expand Down
4 changes: 2 additions & 2 deletions hugr/src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod test {
let _fdef = {
let [b1] = fbuild
.add_dataflow_op(
ops::LeafOp::Lift {
ops::Lift {
type_row: type_row![BIT],
new_extension: PRELUDE_ID,
},
Expand All @@ -147,7 +147,7 @@ mod test {
let const_val = Const::true_val();
let const_wire = loop_b.add_load_const(Const::true_val());
let lift_node = loop_b.add_dataflow_op(
ops::LeafOp::Lift {
ops::Lift {
type_row: vec![const_val.const_type().clone()].into(),
new_extension: PRELUDE_ID,
},
Expand Down
Loading