From 6628b0f03c27b5b86202d630c9cf0b0852cfcbeb Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Tue, 14 May 2024 10:54:36 +0100 Subject: [PATCH] wip --- hugr/src/algorithm.rs | 32 +--------- hugr/src/algorithm/const_fold.rs | 98 +++++++++++------------------ hugr/src/algorithm/verify.rs | 102 +++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 92 deletions(-) create mode 100644 hugr/src/algorithm/verify.rs diff --git a/hugr/src/algorithm.rs b/hugr/src/algorithm.rs index 403632ecd1..7010c395e5 100644 --- a/hugr/src/algorithm.rs +++ b/hugr/src/algorithm.rs @@ -4,34 +4,4 @@ pub mod const_fold; mod half_node; pub mod merge_bbs; pub mod nest_cfgs; - -#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] -/// A type for algorithms to take as configuration, specifying how much -/// verification they should do. Algorithms that accept this configuration -/// should at least verify that input HUGRs are valid, and that output HUGRs are -/// valid. -/// -/// The default level is `None` because verification can be expensive. -pub enum VerifyLevel { - /// Do no verification. - None, - /// Verify using [HugrView::validate_no_extensions]. This is useful when you - /// do not expect valid Extension annotations on Nodes. - /// - /// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions - WithoutExtensions, - /// Verify using [HugrView::validate]. - /// - /// [HugrView::validate]: crate::HugrView::validate - WithExtensions, -} - -impl Default for VerifyLevel { - fn default() -> Self { - if cfg!(test) { - Self::WithoutExtensions - } else { - Self::None - } - } -} +pub mod verify; diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index cf4f617499..7cff27b465 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -5,7 +5,8 @@ use std::collections::{BTreeSet, HashMap}; use itertools::Itertools; use thiserror::Error; -use crate::hugr::{SimpleReplacementError, ValidationError}; +use crate::algorithm::verify::{VerifyError, VerifyLevel}; +use crate::hugr::SimpleReplacementError; use crate::types::SumType; use crate::Direction; use crate::{ @@ -22,91 +23,66 @@ use crate::{ Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; -use super::VerifyLevel; - #[derive(Error, Debug)] #[allow(missing_docs)] pub enum ConstFoldError { - #[error("Failed to verify {label} HUGR: {err}")] - VerifyError { - label: String, - #[source] - err: ValidationError, - }, #[error(transparent)] - SimpleReplaceError(#[from] SimpleReplacementError), -} - -impl ConstFoldError { - fn verify_err(label: impl Into, err: ValidationError) -> Self { - Self::VerifyError { - label: label.into(), - err, - } - } + SimpleReplacementError(#[from] SimpleReplacementError), + #[error(transparent)] + VerifyError(#[from] VerifyError), } #[derive(Debug, Clone, Copy, Default)] /// A configuration for the Constant Folding pass. -pub struct ConstFoldConfig { +pub struct ConstantFoldPass { verify: VerifyLevel, } -impl ConstFoldConfig { +impl ConstantFoldPass { /// Create a new `ConstFoldConfig` with default configuration. pub fn new() -> Self { Self::default() } /// Build a `ConstFoldConfig` with the given [VerifyLevel]. - pub fn with_verify(mut self, verify: VerifyLevel) -> Self { + pub fn verify_level(mut self, verify: VerifyLevel) -> Self { self.verify = verify; self } - fn verify_impl( + /// Run the Constant Folding pass. + pub fn run( &self, - label: &str, - h: &impl HugrView, + hugr: &mut H, reg: &ExtensionRegistry, ) -> Result<(), ConstFoldError> { - match self.verify { - VerifyLevel::None => Ok(()), - VerifyLevel::WithoutExtensions => h.validate_no_extensions(reg), - VerifyLevel::WithExtensions => h.validate(reg), - } - .map_err(|err| ConstFoldError::verify_err(label, err)) - } - - /// Run the Constant Folding pass. - pub fn run(&self, h: &mut impl HugrMut, reg: &ExtensionRegistry) -> Result<(), ConstFoldError> { - self.verify_impl("input", h, reg)?; - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { - break; - }; - h.apply_rewrite(replace)?; - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = h.apply_rewrite(RemoveConst(const_node)); + self.verify.run_verified_pass(hugr, reg, |hugr: &mut H| { + loop { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next() else { + break Ok(()); + }; + hugr.apply_rewrite(replace)?; + for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. + if let Ok(const_node) = hugr.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + let _ = hugr.apply_rewrite(RemoveConst(const_node)); + } } } - } - self.verify_impl("output", h, reg) + }) } } @@ -275,7 +251,7 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option< /// Exhaustively apply constant folding to a HUGR. pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - ConstFoldConfig::default().run(h, reg).unwrap() + ConstantFoldPass::default().run(h, reg).unwrap() } #[cfg(test)] diff --git a/hugr/src/algorithm/verify.rs b/hugr/src/algorithm/verify.rs new file mode 100644 index 0000000000..475636734a --- /dev/null +++ b/hugr/src/algorithm/verify.rs @@ -0,0 +1,102 @@ +//! Provides [VerifyLevel] with tools to run passes with configurable +//! verification. + +use thiserror::Error; + +use crate::{ + extension::ExtensionRegistry, + hugr::{HugrMut, ValidationError}, + HugrView, +}; + +#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] +/// A type for running [HugrMut] algorithms with verification. +/// +/// Provides [VerifyLevel::run_verified_pass] to invoke a closure with pre and +/// post verification. +/// +/// The default level is `None` because verification can be expensive. +pub enum VerifyLevel { + /// Do no verification. + None, + /// Verify using [HugrView::validate_no_extensions]. This is useful when you + /// do not expect valid Extension annotations on Nodes. + /// + /// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions + WithoutExtensions, + /// Verify using [HugrView::validate]. + /// + /// [HugrView::validate]: crate::HugrView::validate + WithExtensions, +} + +#[derive(Error, Debug)] +#[allow(missing_docs)] +pub enum VerifyError { + #[error("Failed to verify input HUGR: {err}\n{pretty_hugr}")] + InputError { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to verify output HUGR: {err}\n{pretty_hugr}")] + OutputError { + #[source] + err: ValidationError, + pretty_hugr: String, + }, +} + +impl Default for VerifyLevel { + fn default() -> Self { + if cfg!(test) { + // Many tests fail when run with Self::WithExtensions + Self::WithoutExtensions + } else { + Self::None + } + } +} + +impl VerifyLevel { + /// Run an operation on a [HugrMut]. `hugr` will be verified according to + /// [self](VerifyLevel), then `pass` will be invoked. If `pass` succeeds + /// then `hugr` will be verified again. + pub fn run_verified_pass( + &self, + hugr: &mut H, + reg: &ExtensionRegistry, + pass: impl FnOnce(&mut H) -> Result, + ) -> Result + where + VerifyError: Into, + { + self.verify_impl(hugr, reg, |err, pretty_hugr| VerifyError::InputError { + err, + pretty_hugr, + })?; + let result = pass(hugr)?; + self.verify_impl(hugr, reg, |err, pretty_hugr| VerifyError::OutputError { + err, + pretty_hugr, + })?; + Ok(result) + } + + fn verify_impl( + &self, + hugr: &impl HugrView, + reg: &ExtensionRegistry, + mk_err: impl FnOnce(ValidationError, String) -> VerifyError, + ) -> Result<(), E> + where + VerifyError: Into, + { + match self { + VerifyLevel::None => Ok(()), + VerifyLevel::WithoutExtensions => hugr.validate_no_extensions(reg), + VerifyLevel::WithExtensions => hugr.validate(reg), + } + .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) + } +}