Skip to content

Commit

Permalink
feat: Add monomorphization and constant folding to QSystemPass
Browse files Browse the repository at this point in the history
Closes #729

Note that constant folding is disabled by default as it currently does
not work on modules.
  • Loading branch information
doug-q committed Dec 17, 2024
1 parent 95090a2 commit cf25210
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 9 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ missing_docs = "warn"
# Make sure to run `just recompile-eccs` if the hugr serialisation format changes.
hugr = "0.14.0"
hugr-core = "0.14.0"
hugr-cli = "0.14.0"
portgraph = "0.12"
pyo3 = "0.23.3"
itertools = "0.13.0"
Expand Down
3 changes: 1 addition & 2 deletions tket2-hseries/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ categories = ["compilers"]

[features]
default = ["cli"]
cli = ["dep:clap", "dep:hugr-cli"]
cli = ["dep:clap"]

[[bin]]
name = "tket2-hseries"
Expand All @@ -32,7 +32,6 @@ strum.workspace = true
strum_macros.workspace = true
itertools.workspace = true
clap = { workspace = true, optional = true }
hugr-cli = { workspace = true, optional = true }
derive_more = { workspace = true, features = [
"error",
"display",
Expand Down
2 changes: 1 addition & 1 deletion tket2-hseries/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ use clap::Parser;
#[non_exhaustive]
pub enum CliArgs {
/// Generate serialized extensions.
GenExtensions(hugr_cli::extensions::ExtArgs),
GenExtensions(tket2::hugr::cli::extensions::ExtArgs),
}
53 changes: 49 additions & 4 deletions tket2-hseries/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
use derive_more::{Display, Error, From};
use hugr::{
algorithms::{
force_order,
const_fold::{ConstFoldError, ConstantFoldPass},
force_order, monomorphize, remove_polyfuncs,
validation::{ValidatePassError, ValidationLevel},
},
hugr::{hugrmut::HugrMut, HugrError},
hugr::HugrError,
Hugr, HugrView,
};
use tket2::Tk2Op;

Expand All @@ -26,9 +28,21 @@ pub mod lazify_measure;
/// Returns an error if this cannot be done.
///
/// To construct a `QSystemPass` use [Default::default].
#[derive(Debug, Clone, Copy, Default)]
#[derive(Debug, Clone, Copy)]
pub struct QSystemPass {
validation_level: ValidationLevel,
constant_fold: bool,
monomorphize: bool,
}

impl Default for QSystemPass {
fn default() -> Self {
Self {
validation_level: ValidationLevel::default(),
constant_fold: false,
monomorphize: true,
}
}
}

#[derive(Error, Debug, Display, From)]
Expand All @@ -43,12 +57,25 @@ pub enum QSystemPassError {
ForceOrderError(HugrError),
/// An error from the component [LowerTket2ToQSystemPass] pass.
LowerTk2Error(LowerTk2Error),
/// An error from the component [ConstantFoldErrorPass] pass.
ConstantFoldError(ConstFoldError),
}

impl QSystemPass {
/// Run `QSystemPass` on the given [HugrMut]. `registry` is used for
/// validation, if enabled.
pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), QSystemPassError> {
pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> {
if self.monomorphize {
self.validation_level.run_validated_pass(hugr, |hugr, _| {
*hugr = remove_polyfuncs(monomorphize(hugr.clone()));

Ok::<_, QSystemPassError>(())
})?;
}

if self.constant_fold {
self.constant_fold().run(hugr)?;
}
self.lower_tk2().run(hugr)?;
self.lazify_measure().run(hugr)?;
self.validation_level.run_validated_pass(hugr, |hugr, _| {
Expand Down Expand Up @@ -77,11 +104,29 @@ impl QSystemPass {
LazifyMeasurePass::default().with_validation_level(self.validation_level)
}

fn constant_fold(&self) -> ConstantFoldPass {
ConstantFoldPass::default().validation_level(self.validation_level)
}

/// Returns a new `QSystemPass` with the given [ValidationLevel].
pub fn with_validation_level(mut self, level: ValidationLevel) -> Self {
self.validation_level = level;
self
}

/// Returns a new `QSystemPass` with constant folding enabled according to
/// `constant_fold`.
pub fn with_constant_fold(mut self, constant_fold: bool) -> Self {
self.constant_fold = constant_fold;
self
}

/// Returns a new `QSystemPass` with monomorphization enabled according to
/// `monomorphsze`.
pub fn with_monormophize(mut self, monomorphize: bool) -> Self {
self.monomorphize = monomorphize;
self
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit cf25210

Please sign in to comment.