diff --git a/Cargo.toml b/Cargo.toml index d3b32cf9..dc116342 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,6 @@ missing_docs = "warn" # Make sure to run `just recompile-eccs` if the hugr serialisation format changes. hugr = "0.14.0" hugr-core = "0.14.0" -hugr-cli = "0.14.0" portgraph = "0.12" pyo3 = "0.23.3" itertools = "0.13.0" diff --git a/tket2-hseries/Cargo.toml b/tket2-hseries/Cargo.toml index 8da0a73c..106392cc 100644 --- a/tket2-hseries/Cargo.toml +++ b/tket2-hseries/Cargo.toml @@ -15,7 +15,7 @@ categories = ["compilers"] [features] default = ["cli"] -cli = ["dep:clap", "dep:hugr-cli"] +cli = ["dep:clap"] [[bin]] name = "tket2-hseries" @@ -32,7 +32,6 @@ strum.workspace = true strum_macros.workspace = true itertools.workspace = true clap = { workspace = true, optional = true } -hugr-cli = { workspace = true, optional = true } derive_more = { workspace = true, features = [ "error", "display", diff --git a/tket2-hseries/src/cli.rs b/tket2-hseries/src/cli.rs index 2bbb41af..67fcd7e4 100644 --- a/tket2-hseries/src/cli.rs +++ b/tket2-hseries/src/cli.rs @@ -10,5 +10,5 @@ use clap::Parser; #[non_exhaustive] pub enum CliArgs { /// Generate serialized extensions. - GenExtensions(hugr_cli::extensions::ExtArgs), + GenExtensions(tket2::hugr::cli::extensions::ExtArgs), } diff --git a/tket2-hseries/src/lib.rs b/tket2-hseries/src/lib.rs index 0a8be413..f59b60e1 100644 --- a/tket2-hseries/src/lib.rs +++ b/tket2-hseries/src/lib.rs @@ -3,10 +3,12 @@ use derive_more::{Display, Error, From}; use hugr::{ algorithms::{ - force_order, + const_fold::{ConstFoldError, ConstantFoldPass}, + force_order, monomorphize, remove_polyfuncs, validation::{ValidatePassError, ValidationLevel}, }, - hugr::{hugrmut::HugrMut, HugrError}, + hugr::HugrError, + Hugr, HugrView, }; use tket2::Tk2Op; @@ -26,9 +28,21 @@ pub mod lazify_measure; /// Returns an error if this cannot be done. /// /// To construct a `QSystemPass` use [Default::default]. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy)] pub struct QSystemPass { validation_level: ValidationLevel, + constant_fold: bool, + monomorphize: bool, +} + +impl Default for QSystemPass { + fn default() -> Self { + Self { + validation_level: ValidationLevel::default(), + constant_fold: false, + monomorphize: true, + } + } } #[derive(Error, Debug, Display, From)] @@ -43,12 +57,25 @@ pub enum QSystemPassError { ForceOrderError(HugrError), /// An error from the component [LowerTket2ToQSystemPass] pass. LowerTk2Error(LowerTk2Error), + /// An error from the component [ConstantFoldErrorPass] pass. + ConstantFoldError(ConstFoldError), } impl QSystemPass { /// Run `QSystemPass` on the given [HugrMut]. `registry` is used for /// validation, if enabled. - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), QSystemPassError> { + pub fn run(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> { + if self.monomorphize { + self.validation_level.run_validated_pass(hugr, |hugr, _| { + *hugr = remove_polyfuncs(monomorphize(hugr.clone())); + + Ok::<_, QSystemPassError>(()) + })?; + } + + if self.constant_fold { + self.constant_fold().run(hugr)?; + } self.lower_tk2().run(hugr)?; self.lazify_measure().run(hugr)?; self.validation_level.run_validated_pass(hugr, |hugr, _| { @@ -77,11 +104,29 @@ impl QSystemPass { LazifyMeasurePass::default().with_validation_level(self.validation_level) } + fn constant_fold(&self) -> ConstantFoldPass { + ConstantFoldPass::default().validation_level(self.validation_level) + } + /// Returns a new `QSystemPass` with the given [ValidationLevel]. pub fn with_validation_level(mut self, level: ValidationLevel) -> Self { self.validation_level = level; self } + + /// Returns a new `QSystemPass` with constant folding enabled according to + /// `constant_fold`. + pub fn with_constant_fold(mut self, constant_fold: bool) -> Self { + self.constant_fold = constant_fold; + self + } + + /// Returns a new `QSystemPass` with monomorphization enabled according to + /// `monomorphsze`. + pub fn with_monormophize(mut self, monomorphize: bool) -> Self { + self.monomorphize = monomorphize; + self + } } #[cfg(test)] diff --git a/uv.lock b/uv.lock index e4ac503a..8c2b94fb 100644 --- a/uv.lock +++ b/uv.lock @@ -823,7 +823,7 @@ wheels = [ [[package]] name = "tket2" -version = "0.5.1" +version = "0.6.0" source = { editable = "tket2-py" } dependencies = [ { name = "hugr" },