Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 14, 2024
1 parent e6c97e6 commit 6628b0f
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 92 deletions.
32 changes: 1 addition & 31 deletions hugr/src/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,4 @@ pub mod const_fold;
mod half_node;
pub mod merge_bbs;
pub mod nest_cfgs;

#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)]
/// A type for algorithms to take as configuration, specifying how much
/// verification they should do. Algorithms that accept this configuration
/// should at least verify that input HUGRs are valid, and that output HUGRs are
/// valid.
///
/// The default level is `None` because verification can be expensive.
pub enum VerifyLevel {
/// Do no verification.
None,
/// Verify using [HugrView::validate_no_extensions]. This is useful when you
/// do not expect valid Extension annotations on Nodes.
///
/// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions
WithoutExtensions,
/// Verify using [HugrView::validate].
///
/// [HugrView::validate]: crate::HugrView::validate
WithExtensions,
}

impl Default for VerifyLevel {
fn default() -> Self {
if cfg!(test) {
Self::WithoutExtensions
} else {
Self::None
}
}
}
pub mod verify;
98 changes: 37 additions & 61 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::collections::{BTreeSet, HashMap};
use itertools::Itertools;
use thiserror::Error;

use crate::hugr::{SimpleReplacementError, ValidationError};
use crate::algorithm::verify::{VerifyError, VerifyLevel};
use crate::hugr::SimpleReplacementError;
use crate::types::SumType;
use crate::Direction;
use crate::{
Expand All @@ -22,91 +23,66 @@ use crate::{
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

use super::VerifyLevel;

#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum ConstFoldError {
#[error("Failed to verify {label} HUGR: {err}")]
VerifyError {
label: String,
#[source]
err: ValidationError,
},
#[error(transparent)]
SimpleReplaceError(#[from] SimpleReplacementError),
}

impl ConstFoldError {
fn verify_err(label: impl Into<String>, err: ValidationError) -> Self {
Self::VerifyError {
label: label.into(),
err,
}
}
SimpleReplacementError(#[from] SimpleReplacementError),
#[error(transparent)]
VerifyError(#[from] VerifyError),
}

#[derive(Debug, Clone, Copy, Default)]
/// A configuration for the Constant Folding pass.
pub struct ConstFoldConfig {
pub struct ConstantFoldPass {
verify: VerifyLevel,
}

impl ConstFoldConfig {
impl ConstantFoldPass {
/// Create a new `ConstFoldConfig` with default configuration.
pub fn new() -> Self {
Self::default()
}

/// Build a `ConstFoldConfig` with the given [VerifyLevel].
pub fn with_verify(mut self, verify: VerifyLevel) -> Self {
pub fn verify_level(mut self, verify: VerifyLevel) -> Self {
self.verify = verify;
self
}

fn verify_impl(
/// Run the Constant Folding pass.
pub fn run<H: HugrMut>(
&self,
label: &str,
h: &impl HugrView,
hugr: &mut H,
reg: &ExtensionRegistry,
) -> Result<(), ConstFoldError> {
match self.verify {
VerifyLevel::None => Ok(()),
VerifyLevel::WithoutExtensions => h.validate_no_extensions(reg),
VerifyLevel::WithExtensions => h.validate(reg),
}
.map_err(|err| ConstFoldError::verify_err(label, err))
}

/// Run the Constant Folding pass.
pub fn run(&self, h: &mut impl HugrMut, reg: &ExtensionRegistry) -> Result<(), ConstFoldError> {
self.verify_impl("input", h, reg)?;
loop {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else {
break;
};
h.apply_rewrite(replace)?;
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = h.apply_rewrite(RemoveConst(const_node));
self.verify.run_verified_pass(hugr, reg, |hugr: &mut H| {
loop {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next() else {
break Ok(());
};
hugr.apply_rewrite(replace)?;
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = hugr.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = hugr.apply_rewrite(RemoveConst(const_node));
}
}
}
}
self.verify_impl("output", h, reg)
})
}
}

Expand Down Expand Up @@ -275,7 +251,7 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass<H: HugrMut>(h: &mut H, reg: &ExtensionRegistry) {
ConstFoldConfig::default().run(h, reg).unwrap()
ConstantFoldPass::default().run(h, reg).unwrap()
}

#[cfg(test)]
Expand Down
102 changes: 102 additions & 0 deletions hugr/src/algorithm/verify.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//! Provides [VerifyLevel] with tools to run passes with configurable
//! verification.
use thiserror::Error;

use crate::{
extension::ExtensionRegistry,
hugr::{HugrMut, ValidationError},
HugrView,
};

#[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)]
/// A type for running [HugrMut] algorithms with verification.
///
/// Provides [VerifyLevel::run_verified_pass] to invoke a closure with pre and
/// post verification.
///
/// The default level is `None` because verification can be expensive.
pub enum VerifyLevel {
/// Do no verification.
None,
/// Verify using [HugrView::validate_no_extensions]. This is useful when you
/// do not expect valid Extension annotations on Nodes.
///
/// [HugrView::validate_no_extensions]: crate::HugrView::validate_no_extensions
WithoutExtensions,
/// Verify using [HugrView::validate].
///
/// [HugrView::validate]: crate::HugrView::validate
WithExtensions,
}

#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum VerifyError {
#[error("Failed to verify input HUGR: {err}\n{pretty_hugr}")]
InputError {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error("Failed to verify output HUGR: {err}\n{pretty_hugr}")]
OutputError {
#[source]
err: ValidationError,
pretty_hugr: String,
},
}

impl Default for VerifyLevel {
fn default() -> Self {
if cfg!(test) {
// Many tests fail when run with Self::WithExtensions
Self::WithoutExtensions
} else {
Self::None
}
}
}

impl VerifyLevel {
/// Run an operation on a [HugrMut]. `hugr` will be verified according to
/// [self](VerifyLevel), then `pass` will be invoked. If `pass` succeeds
/// then `hugr` will be verified again.
pub fn run_verified_pass<H: HugrMut, E, T>(
&self,
hugr: &mut H,
reg: &ExtensionRegistry,
pass: impl FnOnce(&mut H) -> Result<T, E>,
) -> Result<T, E>
where
VerifyError: Into<E>,
{
self.verify_impl(hugr, reg, |err, pretty_hugr| VerifyError::InputError {
err,
pretty_hugr,
})?;
let result = pass(hugr)?;
self.verify_impl(hugr, reg, |err, pretty_hugr| VerifyError::OutputError {
err,
pretty_hugr,
})?;
Ok(result)
}

fn verify_impl<E>(
&self,
hugr: &impl HugrView,
reg: &ExtensionRegistry,
mk_err: impl FnOnce(ValidationError, String) -> VerifyError,
) -> Result<(), E>
where
VerifyError: Into<E>,
{
match self {
VerifyLevel::None => Ok(()),
VerifyLevel::WithoutExtensions => hugr.validate_no_extensions(reg),
VerifyLevel::WithExtensions => hugr.validate(reg),
}
.map_err(|err| mk_err(err, hugr.mermaid_string()).into())
}
}

0 comments on commit 6628b0f

Please sign in to comment.