Skip to content

Commit

Permalink
feat: Add ValidationLevel tooling and apply to constant_fold_pass (
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q authored Jun 6, 2024
1 parent e993580 commit 6eb6d56
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 46 deletions.
109 changes: 63 additions & 46 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::collections::{BTreeSet, HashMap};
use itertools::Itertools;
use thiserror::Error;

use hugr_core::hugr::{SimpleReplacementError, ValidationError};
use hugr_core::hugr::SimpleReplacementError;
use hugr_core::types::SumType;
use hugr_core::Direction;
use hugr_core::{
Expand All @@ -23,17 +23,71 @@ use hugr_core::{
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

use crate::validation::{ValidatePassError, ValidationLevel};

#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum ConstFoldError {
#[error("Failed to verify {label} HUGR: {err}")]
VerifyError {
label: String,
#[source]
err: ValidationError,
},
#[error(transparent)]
SimpleReplaceError(#[from] SimpleReplacementError),
SimpleReplacementError(#[from] SimpleReplacementError),
#[error(transparent)]
ValidationError(#[from] ValidatePassError),
}

#[derive(Debug, Clone, Copy, Default)]
/// A configuration for the Constant Folding pass.
pub struct ConstantFoldPass {
validation: ValidationLevel,
}

impl ConstantFoldPass {
/// Create a new `ConstFoldConfig` with default configuration.
pub fn new() -> Self {
Self::default()
}

/// Build a `ConstFoldConfig` with the given [ValidationLevel].
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
self.validation = level;
self
}

/// Run the Constant Folding pass.
pub fn run<H: HugrMut>(
&self,
hugr: &mut H,
reg: &ExtensionRegistry,
) -> Result<(), ConstFoldError> {
self.validation
.run_validated_pass(hugr, reg, |hugr: &mut H, _| {
loop {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next()
else {
break Ok(());
};
hugr.apply_rewrite(replace)?;
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = hugr.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = hugr.apply_rewrite(RemoveConst(const_node));
}
}
}
})
}
}

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
Expand Down Expand Up @@ -186,44 +240,7 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass<H: HugrMut>(h: &mut H, reg: &ExtensionRegistry) {
#[cfg(test)]
let verify = |label, h: &H| {
h.validate_no_extensions(reg).unwrap_or_else(|err| {
panic!(
"constant_fold_pass: failed to verify {label} HUGR: {err}\n{}",
h.mermaid_string()
)
})
};
#[cfg(test)]
verify("input", h);
loop {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else {
break;
};
h.apply_rewrite(replace).unwrap();
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = h.apply_rewrite(RemoveConst(const_node));
}
}
}
#[cfg(test)]
verify("output", h);
ConstantFoldPass::default().run(h, reg).unwrap()
}

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod const_fold;
mod half_node;
pub mod merge_bbs;
pub mod nest_cfgs;
pub mod validation;
96 changes: 96 additions & 0 deletions hugr-passes/src/validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//! Provides [ValidationLevel] with tools to run passes with configurable
//! validation.
use thiserror::Error;

use hugr_core::{
extension::ExtensionRegistry,
hugr::{hugrmut::HugrMut, ValidationError},
HugrView,
};

#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)]
/// A type for running [HugrMut] algorithms with verification.
///
/// Provides [ValidationLevel::run_validated_pass] to invoke a closure with pre and post
/// validation.
///
/// The default level is `None` because validation can be expensive.
pub enum ValidationLevel {
/// Do no verification.
None,
/// Validate using [HugrView::validate_no_extensions]. This is useful when you
/// do not expect valid Extension annotations on Nodes.
WithoutExtensions,
/// Validate using [HugrView::validate].
WithExtensions,
}

#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum ValidatePassError {
#[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
InputError {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
OutputError {
#[source]
err: ValidationError,
pretty_hugr: String,
},
}

impl Default for ValidationLevel {
fn default() -> Self {
if cfg!(test) {
// Many tests fail when run with Self::WithExtensions
Self::WithoutExtensions
} else {
Self::None
}
}
}

impl ValidationLevel {
/// Run an operation on a [HugrMut]. `hugr` will be verified according to
/// [self](ValidationLevel), then `pass` will be invoked. If `pass` succeeds
/// then `hugr` will be verified again.
pub fn run_validated_pass<H: HugrMut, E, T>(
&self,
hugr: &mut H,
reg: &ExtensionRegistry,
pass: impl FnOnce(&mut H, &Self) -> Result<T, E>,
) -> Result<T, E>
where
ValidatePassError: Into<E>,
{
self.validation_impl(hugr, reg, |err, pretty_hugr| {
ValidatePassError::InputError { err, pretty_hugr }
})?;
let result = pass(hugr, self)?;
self.validation_impl(hugr, reg, |err, pretty_hugr| {
ValidatePassError::OutputError { err, pretty_hugr }
})?;
Ok(result)
}

fn validation_impl<E>(
&self,
hugr: &impl HugrView,
reg: &ExtensionRegistry,
mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError,
) -> Result<(), E>
where
ValidatePassError: Into<E>,
{
match self {
ValidationLevel::None => Ok(()),
ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(reg),
ValidationLevel::WithExtensions => hugr.validate(reg),
}
.map_err(|err| mk_err(err, hugr.mermaid_string()).into())
}
}

0 comments on commit 6eb6d56

Please sign in to comment.