Skip to content

Commit

Permalink
fix!: Normalize input/output value/static/other ports in OpType (#783)
Browse files Browse the repository at this point in the history
Closes #777.

Now `OpTrait` clearly separates `other_port` from `static_port`, and
`OpType` has a uniform interface for `value` ports (from the signature),
`static` ports (contstants), and `other` ports.

Fixes #779. The bug was caused by the encoder ignoring the constant
input port offset, so the decoder reconnected the edge to the
`StateOrder` port instead. Now we use the proper OpType methods.

Fixes #778.

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
aborgna-q and ss2165 authored Jan 8, 2024
1 parent ca07831 commit a8f6254
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 70 deletions.
6 changes: 2 additions & 4 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,11 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
let src_port = self
.get_optype(src)
.other_output_port()
.expect("Source operation has no non-dataflow outgoing edges")
.as_outgoing()?;
.expect("Source operation has no non-dataflow outgoing edges");
let dst_port = self
.get_optype(dst)
.other_input_port()
.expect("Destination operation has no non-dataflow incoming edges")
.as_incoming()?;
.expect("Destination operation has no non-dataflow incoming edges");
self.connect(src, src_port, dst, dst_port)?;
Ok((src_port, dst_port))
}
Expand Down
25 changes: 20 additions & 5 deletions src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,10 @@ impl TryFrom<&Hugr> for SerHugrV0 {
.expect("Could not reach one of the nodes");

let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| {
let sig = hugr.signature(node).unwrap_or_default();
let offset = match offset < sig.port_count(dir) {
true => Some(offset as u16),
false => None,
};
let op = hugr.get_optype(node);
let is_value_port = offset < op.value_port_count(dir);
let is_static_input = op.static_port(dir).map_or(false, |p| p.index() == offset);
let offset = (is_value_port || is_static_input).then_some(offset as u16);
(node_rekey[&node], offset)
};

Expand Down Expand Up @@ -263,6 +262,8 @@ pub mod test {
use crate::hugr::NodeType;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG};
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::logic::NotOp;
use crate::types::{FunctionType, Type};
use crate::OutgoingPort;
Expand Down Expand Up @@ -480,4 +481,18 @@ pub mod test {
new_hugr.validate(&EMPTY_REG).unwrap_err();
new_hugr.validate(&PRELUDE_REGISTRY).unwrap();
}

#[test]
fn constants_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
let mut builder = DFGBuilder::new(FunctionType::new(vec![], vec![FLOAT64_TYPE])).unwrap();
let w = builder.add_load_const(ConstF64::new(0.5))?;
let hugr = builder.finish_hugr_with_outputs([w], &FLOAT_OPS_REGISTRY)?;

let ser = serde_json::to_string(&hugr)?;
let deser = serde_json::from_str(&ser)?;

assert_eq!(hugr, deser);

Ok(())
}
}
162 changes: 107 additions & 55 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod module;
pub mod tag;
pub mod validate;
use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, FunctionType, Type};
use crate::types::{EdgeKind, FunctionType};
use crate::{Direction, OutgoingPort, Port};
use crate::{IncomingPort, PortIndex};
use paste::paste;
Expand Down Expand Up @@ -111,33 +111,59 @@ impl Default for OpType {
}

impl OpType {
/// The edge kind for the non-dataflow or constant ports of the
/// operation, not described by the signature.
/// The edge kind for the non-dataflow ports of the operation, not described
/// by the signature.
///
/// If not None, a single extra multiport of that kind will be present on
/// the given direction.
/// If not None, a single extra port of that kind will be present on
/// the given direction after any dataflow or constant ports.
#[inline]
pub fn other_port_kind(&self, dir: Direction) -> Option<EdgeKind> {
match dir {
Direction::Incoming => self.other_input(),
Direction::Outgoing => self.other_output(),
}
}

/// The edge kind for the static ports of the operation, not described by
/// the dataflow signature.
///
/// If not None, an extra input port of that kind will be present on the
/// given direction after any dataflow ports and before any
/// [`OpType::other_port_kind`] ports.
#[inline]
pub fn static_port_kind(&self, dir: Direction) -> Option<EdgeKind> {
match dir {
Direction::Incoming => self.static_input(),
Direction::Outgoing => self.static_output(),
}
}

/// Returns the edge kind for the given port.
///
/// The result may be a value port, a static port, or a non-dataflow port.
/// See [`OpType::dataflow_signature`], [`OpType::static_port_kind`], and
/// [`OpType::other_port_kind`].
pub fn port_kind(&self, port: impl Into<Port>) -> Option<EdgeKind> {
let signature = self.dataflow_signature().unwrap_or_default();
let port: Port = port.into();
let port_as_in = port.as_incoming().ok();
let dir = port.direction();

let port_count = signature.port_count(dir);

// Dataflow ports
if port.index() < port_count {
signature.port_type(port).cloned().map(EdgeKind::Value)
} else if port_as_in.is_some() && port_as_in == self.static_input_port() {
Some(EdgeKind::Static(static_in_type(self)))
} else {
self.other_port_kind(dir)
return signature.port_type(port).cloned().map(EdgeKind::Value);
}

// Constant port
let static_kind = self.static_port_kind(dir);
if port.index() == port_count {
if let Some(kind) = static_kind {
return Some(kind);
}
}

// Non-dataflow ports
self.other_port_kind(dir)
}

/// The non-dataflow port for the operation, not described by the signature.
Expand All @@ -157,84 +183,92 @@ impl OpType {
}
}

/// The non-dataflow input port for the operation, not described by the signature.
/// See `[OpType::other_port]`.
#[inline]
pub fn other_input_port(&self) -> Option<IncomingPort> {
self.other_port(Direction::Incoming)
.map(|p| p.as_incoming().unwrap())
}

/// The non-dataflow output port for the operation, not described by the signature.
/// See `[OpType::other_port]`.
#[inline]
pub fn other_output_port(&self) -> Option<OutgoingPort> {
self.other_port(Direction::Outgoing)
.map(|p| p.as_outgoing().unwrap())
}

/// If the op has a static port, the port of that input.
///
/// See [`OpType::static_input_port`] and [`OpType::static_output_port`].
#[inline]
pub fn static_port(&self, dir: Direction) -> Option<Port> {
self.static_port_kind(dir)?;
Some(Port::new(dir, self.value_port_count(dir)))
}

/// If the op has a static input ([`Call`] and [`LoadConstant`]), the port of that input.
#[inline]
pub fn static_input_port(&self) -> Option<IncomingPort> {
self.static_port(Direction::Incoming)
.map(|p| p.as_incoming().unwrap())
}

/// If the op has a static output ([`Const`], [`FuncDefn`], [`FuncDecl`]), the port of that output.
#[inline]
pub fn static_output_port(&self) -> Option<OutgoingPort> {
self.static_port(Direction::Outgoing)
.map(|p| p.as_outgoing().unwrap())
}

/// The number of Value ports in given direction.
#[inline]
pub fn value_port_count(&self, dir: portgraph::Direction) -> usize {
self.dataflow_signature()
.map(|sig| sig.port_count(dir))
.unwrap_or(0)
}

/// The number of Value input ports.
#[inline]
pub fn value_input_count(&self) -> usize {
self.value_port_count(Direction::Incoming)
}

/// The number of Value output ports.
#[inline]
pub fn value_output_count(&self) -> usize {
self.value_port_count(Direction::Outgoing)
}

/// The non-dataflow input port for the operation, not described by the signature.
/// See `[OpType::other_port]`.
pub fn other_input_port(&self) -> Option<Port> {
self.other_port(Direction::Incoming)
}

/// The non-dataflow input port for the operation, not described by the signature.
/// See `[OpType::other_port]`.
pub fn other_output_port(&self) -> Option<Port> {
self.other_port(Direction::Outgoing)
}

/// If the op has a static input (Call and LoadConstant), the port of that input.
pub fn static_input_port(&self) -> Option<IncomingPort> {
match self {
OpType::Call(call) => Some(call.called_function_port()),
OpType::LoadConstant(l) => Some(l.constant_port()),
_ => None,
}
}

/// If the op has a static output (Const, FuncDefn, FuncDecl), the port of that output.
pub fn static_output_port(&self) -> Option<OutgoingPort> {
OpTag::StaticOutput
.is_superset(self.tag())
.then_some(0.into())
}

/// Returns the number of ports for the given direction.
#[inline]
pub fn port_count(&self, dir: Direction) -> usize {
let has_static_port = self.static_port_kind(dir).is_some();
let non_df_count = self.non_df_port_count(dir);
// if there is a static input it comes before the non_df_ports
let static_input =
(dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize;
self.value_port_count(dir) + non_df_count + static_input
self.value_port_count(dir) + has_static_port as usize + non_df_count
}

/// Returns the number of inputs ports for the operation.
#[inline]
pub fn input_count(&self) -> usize {
self.port_count(Direction::Incoming)
}

/// Returns the number of outputs ports for the operation.
#[inline]
pub fn output_count(&self) -> usize {
self.port_count(Direction::Outgoing)
}

/// Checks whether the operation can contain children nodes.
#[inline]
pub fn is_container(&self) -> bool {
self.validity_flags().allowed_children != OpTag::None
}
}

fn static_in_type(op: &OpType) -> Type {
match op {
OpType::Call(call) => Type::new_function(call.called_function_type().clone()),
OpType::LoadConstant(load) => load.constant_type().clone(),
_ => panic!("this function should not be called if the optype is not known to be Call or LoadConst.")
}
}

/// Macro used by operations that want their
/// name to be the same as their type name
macro_rules! impl_op_name {
Expand Down Expand Up @@ -288,10 +322,10 @@ pub trait OpTrait {
ExtensionSet::new()
}

/// The edge kind for the non-dataflow or constant inputs of the operation,
/// The edge kind for the non-dataflow inputs of the operation,
/// not described by the signature.
///
/// If not None, a single extra output multiport of that kind will be
/// If not None, a single extra input port of that kind will be
/// present.
fn other_input(&self) -> Option<EdgeKind> {
None
Expand All @@ -300,12 +334,30 @@ pub trait OpTrait {
/// The edge kind for the non-dataflow outputs of the operation, not
/// described by the signature.
///
/// If not None, a single extra output multiport of that kind will be
/// If not None, a single extra output port of that kind will be
/// present.
fn other_output(&self) -> Option<EdgeKind> {
None
}

/// The edge kind for a single constant input of the operation, not
/// described by the dataflow signature.
///
/// If not None, an extra input port of that kind will be present after the
/// dataflow input ports and before any [`OpTrait::other_input`] ports.
fn static_input(&self) -> Option<EdgeKind> {
None
}

/// The edge kind for a single constant output of the operation, not
/// described by the dataflow signature.
///
/// If not None, an extra output port of that kind will be present after the
/// dataflow input ports and before any [`OpTrait::other_output`] ports.
fn static_output(&self) -> Option<EdgeKind> {
None
}

/// Get the number of non-dataflow multiports.
fn non_df_port_count(&self, dir: Direction) -> usize {
match dir {
Expand Down
2 changes: 1 addition & 1 deletion src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl OpTrait for Const {
<Self as StaticTag>::TAG
}

fn other_output(&self) -> Option<EdgeKind> {
fn static_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::Static(self.typ.clone()))
}
}
Expand Down
Loading

0 comments on commit a8f6254

Please sign in to comment.