From 6eb6d56e53e0487c8cd2218d889e2f69cab49884 Mon Sep 17 00:00:00 2001 From: doug-q <141026920+doug-q@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:26:30 +0100 Subject: [PATCH] feat: Add `ValidationLevel` tooling and apply to `constant_fold_pass` (#1035) --- hugr-passes/src/const_fold.rs | 109 ++++++++++++++++++++-------------- hugr-passes/src/lib.rs | 1 + hugr-passes/src/validation.rs | 96 ++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 46 deletions(-) create mode 100644 hugr-passes/src/validation.rs diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index bead78c50..b882699d5 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -5,7 +5,7 @@ use std::collections::{BTreeSet, HashMap}; use itertools::Itertools; use thiserror::Error; -use hugr_core::hugr::{SimpleReplacementError, ValidationError}; +use hugr_core::hugr::SimpleReplacementError; use hugr_core::types::SumType; use hugr_core::Direction; use hugr_core::{ @@ -23,17 +23,71 @@ use hugr_core::{ Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; +use crate::validation::{ValidatePassError, ValidationLevel}; + #[derive(Error, Debug)] #[allow(missing_docs)] pub enum ConstFoldError { - #[error("Failed to verify {label} HUGR: {err}")] - VerifyError { - label: String, - #[source] - err: ValidationError, - }, #[error(transparent)] - SimpleReplaceError(#[from] SimpleReplacementError), + SimpleReplacementError(#[from] SimpleReplacementError), + #[error(transparent)] + ValidationError(#[from] ValidatePassError), +} + +#[derive(Debug, Clone, Copy, Default)] +/// A configuration for the Constant Folding pass. +pub struct ConstantFoldPass { + validation: ValidationLevel, +} + +impl ConstantFoldPass { + /// Create a new `ConstFoldConfig` with default configuration. + pub fn new() -> Self { + Self::default() + } + + /// Build a `ConstFoldConfig` with the given [ValidationLevel]. + pub fn validation_level(mut self, level: ValidationLevel) -> Self { + self.validation = level; + self + } + + /// Run the Constant Folding pass. + pub fn run( + &self, + hugr: &mut H, + reg: &ExtensionRegistry, + ) -> Result<(), ConstFoldError> { + self.validation + .run_validated_pass(hugr, reg, |hugr: &mut H, _| { + loop { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next() + else { + break Ok(()); + }; + hugr.apply_rewrite(replace)?; + for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. + if let Ok(const_node) = hugr.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + let _ = hugr.apply_rewrite(RemoveConst(const_node)); + } + } + } + }) + } } /// Tag some output constants with [`OutgoingPort`] inferred from the ordering. @@ -186,44 +240,7 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option< /// Exhaustively apply constant folding to a HUGR. pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - #[cfg(test)] - let verify = |label, h: &H| { - h.validate_no_extensions(reg).unwrap_or_else(|err| { - panic!( - "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}", - h.mermaid_string() - ) - }) - }; - #[cfg(test)] - verify("input", h); - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { - break; - }; - h.apply_rewrite(replace).unwrap(); - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = h.apply_rewrite(RemoveConst(const_node)); - } - } - } - #[cfg(test)] - verify("output", h); + ConstantFoldPass::default().run(h, reg).unwrap() } #[cfg(test)] diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 803196144..f6e09b71b 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -4,3 +4,4 @@ pub mod const_fold; mod half_node; pub mod merge_bbs; pub mod nest_cfgs; +pub mod validation; diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs new file mode 100644 index 000000000..68fa43601 --- /dev/null +++ b/hugr-passes/src/validation.rs @@ -0,0 +1,96 @@ +//! Provides [ValidationLevel] with tools to run passes with configurable +//! validation. + +use thiserror::Error; + +use hugr_core::{ + extension::ExtensionRegistry, + hugr::{hugrmut::HugrMut, ValidationError}, + HugrView, +}; + +#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] +/// A type for running [HugrMut] algorithms with verification. +/// +/// Provides [ValidationLevel::run_validated_pass] to invoke a closure with pre and post +/// validation. +/// +/// The default level is `None` because validation can be expensive. +pub enum ValidationLevel { + /// Do no verification. + None, + /// Validate using [HugrView::validate_no_extensions]. This is useful when you + /// do not expect valid Extension annotations on Nodes. + WithoutExtensions, + /// Validate using [HugrView::validate]. + WithExtensions, +} + +#[derive(Error, Debug)] +#[allow(missing_docs)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + InputError { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + OutputError { + #[source] + err: ValidationError, + pretty_hugr: String, + }, +} + +impl Default for ValidationLevel { + fn default() -> Self { + if cfg!(test) { + // Many tests fail when run with Self::WithExtensions + Self::WithoutExtensions + } else { + Self::None + } + } +} + +impl ValidationLevel { + /// Run an operation on a [HugrMut]. `hugr` will be verified according to + /// [self](ValidationLevel), then `pass` will be invoked. If `pass` succeeds + /// then `hugr` will be verified again. + pub fn run_validated_pass( + &self, + hugr: &mut H, + reg: &ExtensionRegistry, + pass: impl FnOnce(&mut H, &Self) -> Result, + ) -> Result + where + ValidatePassError: Into, + { + self.validation_impl(hugr, reg, |err, pretty_hugr| { + ValidatePassError::InputError { err, pretty_hugr } + })?; + let result = pass(hugr, self)?; + self.validation_impl(hugr, reg, |err, pretty_hugr| { + ValidatePassError::OutputError { err, pretty_hugr } + })?; + Ok(result) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + reg: &ExtensionRegistry, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), E> + where + ValidatePassError: Into, + { + match self { + ValidationLevel::None => Ok(()), + ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(reg), + ValidationLevel::WithExtensions => hugr.validate(reg), + } + .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) + } +}