Skip to content

Commit

Permalink
Merge branch 'main' of gh:CQCL-DEV/hugr into feat/const-serialisation2
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 9, 2024
2 parents ddc485e + b0eb9d3 commit 5f6c22b
Show file tree
Hide file tree
Showing 19 changed files with 2,892 additions and 156 deletions.
74 changes: 58 additions & 16 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,15 @@ mod test {

use super::*;
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::PRELUDE;
use crate::ops::UnpackTuple;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::{OpType, UnpackTuple};
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use crate::std_extensions::logic::{self, NaryLogic};
use crate::std_extensions::logic::{self, NaryLogic, NotOp};
use crate::utils::test::assert_fully_folded;

use rstest::rstest;

Expand Down Expand Up @@ -274,7 +275,7 @@ mod test {
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();
let to_int = build
.add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs())
.add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
Expand Down Expand Up @@ -362,19 +363,60 @@ mod test {
Ok(())
}

fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
// check the hugr just loads and returns a single const
let mut node_count = 0;
#[test]
fn test_fold_and() {
// pseudocode:
// x0, x1 := bool(true), bool(true)
// x2 := and(x0, x1)
// output x2 == true;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_load_const(Value::true_val());
let x2 = build
.add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1])
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c.value() == expected_value => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
}
}
#[test]
fn test_fold_or() {
// pseudocode:
// x0, x1 := bool(true), bool(false)
// x2 := or(x0, x1)
// output x2 == true;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_load_const(Value::false_val());
let x2 = build
.add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1])
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

assert_eq!(node_count, 4);
#[test]
fn test_fold_not() {
// pseudocode:
// x0 := bool(true)
// x1 := not(x0)
// output x1 == false;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::false_val();
assert_fully_folded(&h, &expected);
}
}
2 changes: 1 addition & 1 deletion hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult};
pub use const_fold::{ConstFold, ConstFoldResult, Folder};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

pub mod declarative;
Expand Down
16 changes: 16 additions & 0 deletions hugr/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use std::fmt::Formatter;

use std::fmt::Debug;

use crate::ops::Value;
use crate::types::TypeArg;

use crate::IncomingPort;
use crate::OutgoingPort;

use crate::ops;
Expand Down Expand Up @@ -45,3 +47,17 @@ where
self(consts)
}
}

type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync;

/// Type holding a boxed const-folding function.
pub struct Folder {
/// Const-folding function.
pub folder: Box<FoldFn>,
}

impl ConstFold for Folder {
fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
(self.folder)(type_args, consts)
}
}
4 changes: 2 additions & 2 deletions hugr/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ pub const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T);
pub const BOOL_T: Type = Type::new_unit_sum(2);

/// Initialize a new array of element type `element_ty` of length `size`
pub fn array_type(size: TypeArg, element_ty: Type) -> Type {
pub fn array_type(size: impl Into<TypeArg>, element_ty: Type) -> Type {
let array_def = PRELUDE.get_type("array").unwrap();
let custom_t = array_def
.instantiate(vec![size, TypeArg::Type { ty: element_ty }])
.instantiate(vec![size.into(), element_ty.into()])
.unwrap();
Type::new_extension(custom_t)
}
Expand Down
10 changes: 8 additions & 2 deletions hugr/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(feature = "extension_inference")]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand Down Expand Up @@ -196,8 +198,12 @@ impl Hugr {
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
resolve_extension_ops(self, extension_registry)?;
self.infer_extensions()?;
self.validate(extension_registry)?;
self.validate_no_extensions(extension_registry)?;
#[cfg(feature = "extension_inference")]
{
self.infer_extensions()?;
self.validate_extensions(HashMap::new())?;
}
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions hugr/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ mod test {
)?;
let [a] = inner.input_wires_arr();
let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_width(6), [a, c1])?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
inner.finish_with_outputs(a1.outputs())?
};
let [a1] = inner.outputs_arr();

let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_width(6), [a1, b])?;
let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs(), &reg)?;

// Sanity checks
Expand Down
78 changes: 50 additions & 28 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ struct ValidationContext<'a, 'b> {
hugr: &'a Hugr,
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
#[allow(dead_code)]
extension_validator: ExtensionValidator,
/// Registry of available Extensions
extension_registry: &'b ExtensionRegistry,
}
Expand All @@ -48,7 +45,51 @@ impl Hugr {
/// TODO: Add a version of validation which allows for open extension
/// variables (see github issue #457)
pub fn validate(&self, extension_registry: &ExtensionRegistry) -> Result<(), ValidationError> {
self.validate_with_extension_closure(HashMap::new(), extension_registry)
#[cfg(feature = "extension_inference")]
self.validate_with_extension_closure(HashMap::new(), extension_registry)?;
#[cfg(not(feature = "extension_inference"))]
self.validate_no_extensions(extension_registry)?;
Ok(())
}

/// Check the validity of the HUGR, but don't check consistency of extension
/// requirements between connected nodes or between parents and children.
pub fn validate_no_extensions(
&self,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let mut validator = ValidationContext::new(self, extension_registry);
validator.validate()
}

/// Validate extensions on the input and output edges of nodes. Check that
/// the target ends of edges require the extensions from the sources, and
/// check extension deltas from parent nodes are reflected in their children
pub fn validate_extensions(&self, closure: ExtensionSolution) -> Result<(), ValidationError> {
let validator = ExtensionValidator::new(self, closure);
for src_node in self.nodes() {
let node_type = self.get_nodetype(src_node);

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.get_io(src_node) {
validator.validate_io_extensions(src_node, input, output)?;
}
}

for src_port in self.node_outputs(src_node) {
for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) {
validator.check_extensions_compatible(
&(src_node, src_port.into()),
&(tgt_node, tgt_port.into()),
)?;
}
}
}
Ok(())
}

/// Check the validity of a hugr, taking an argument of a closure for the
Expand All @@ -58,8 +99,10 @@ impl Hugr {
closure: ExtensionSolution,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let mut validator = ValidationContext::new(self, closure, extension_registry);
validator.validate()
let mut validator = ValidationContext::new(self, extension_registry);
validator.validate()?;
self.validate_extensions(closure)?;
Ok(())
}
}

Expand All @@ -68,15 +111,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Allow unused "extension_closure" variable for when
// the "extension_inference" feature is disabled.
#[allow(unused_variables)]
pub fn new(
hugr: &'a Hugr,
extension_closure: ExtensionSolution,
extension_registry: &'b ExtensionRegistry,
) -> Self {
pub fn new(hugr: &'a Hugr, extension_registry: &'b ExtensionRegistry) -> Self {
Self {
hugr,
dominators: HashMap::new(),
extension_validator: ExtensionValidator::new(hugr, extension_closure),
extension_registry,
}
}
Expand Down Expand Up @@ -176,18 +214,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Secondly that the node has correct children
self.validate_children(node, node_type)?;

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
#[cfg(feature = "extension_inference")]
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
}
}

Ok(())
}

Expand Down Expand Up @@ -247,10 +273,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();

#[cfg(feature = "extension_inference")]
self.extension_validator
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;

let other_op = self.hugr.get_optype(other_node);
let Some(other_kind) = other_op.port_kind(other_offset) else {
panic!("The number of ports in {other_node} does not match the operation definition. This should have been caught by `validate_node`.");
Expand Down
Loading

0 comments on commit 5f6c22b

Please sign in to comment.