Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Oct 10, 2023
1 parent 631a0e2 commit daafc2c
Showing 1 changed file with 28 additions and 34 deletions.
62 changes: 28 additions & 34 deletions pyrs/src/pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,17 @@ fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
})
}

/// TASO optimisation pass.
///
/// HyperTKET's best attempt at optimising a circuit using circuit rewriting
/// and TASO.
///
/// Input can be in any gate set and will be rebased to Nam, i.e. CX + Rz + H.
///
/// Will use at most `max_threads` threads (plus a constant) and take at most
/// `timeout` seconds (plus a constant). Default to the number of cpus and
/// 15min respectively.
/// Rebase a circuit to the Nam gate set (CX, Rz, H) using TKET1.
///
/// Log files will be written to the directory `log_dir` if specified.
/// Acquires the python GIL to call TKET's `auto_rebase_pass`.
///
/// This requires a `nam_6_3.rwr` file in the current directory. The location
/// can alternatively be specified using the `rewriter_dir` argument.
#[pyfunction]
fn taso_optimise(
circ: PyObject,
max_threads: Option<NonZeroUsize>,
rewriter_dir: Option<PathBuf>,
timeout: Option<u64>,
log_dir: Option<PathBuf>,
) -> PyResult<PyObject> {
// Runs the following code:
// ```python
// from pytket.passes.auto_rebase import auto_rebase_pass
// from pytket import OpType
// auto_rebase_pass({OpType.CX, OpType.Rz, OpType.H}).apply(circ)"
// ```
/// Equivalent to running the following code:
/// ```python
/// from pytket.passes.auto_rebase import auto_rebase_pass
/// from pytket import OpType
/// auto_rebase_pass({OpType.CX, OpType.Rz, OpType.H}).apply(circ)"
// ```
fn rebase_nam(circ: &PyObject) -> PyResult<()> {
Python::with_gil(|py| {
let auto_rebase = py
.import("pytket.passes.auto_rebase")?
Expand All @@ -55,40 +37,50 @@ fn taso_optimise(
let locals = [("OpType", &optype)].into_py_dict(py);
let op_set = py.eval("{OpType.CX, OpType.Rz, OpType.H}", None, Some(locals))?;
let rebase_pass = auto_rebase.call1((op_set,))?.getattr("apply")?;
rebase_pass.call1((&circ,)).map(|_| ())
})?;
nam_taso_optimise(circ, max_threads, rewriter_dir, timeout, log_dir)
rebase_pass.call1((circ,)).map(|_| ())
})
}

/// TASO optimisation pass.
///
/// HyperTKET's best attempt at optimising a circuit using circuit rewriting
/// and TASO.
///
/// Input must be in the Nam gate set, i.e. CX + Rz + H.
/// By default, the input circuit will be rebased to Nam, i.e. CX + Rz + H before
/// optimising. This can be deactivated by setting `rebase` to `false`, in which
/// case the circuit is expected to be in the Nam gate set.
///
/// Will use at most `max_threads` threads (plus a constant) and take at most
/// `timeout` seconds (plus a constant). Default to the number of cpus and
/// 30s respectively.
/// 15min respectively.
///
/// Log files will be written to the directory `log_dir` if specified.
///
/// This requires a `nam_6_3.rwr` file in the current directory. The location
/// can alternatively be specified using the `rewriter_dir` argument.
#[pyfunction]
fn nam_taso_optimise(
fn taso_optimise(
circ: PyObject,
max_threads: Option<NonZeroUsize>,
rewriter_dir: Option<PathBuf>,
timeout: Option<u64>,
log_dir: Option<PathBuf>,
rebase: Option<bool>,
) -> PyResult<PyObject> {
// Default parameter values
let rebase = rebase.unwrap_or(true);
let max_threads = max_threads.unwrap_or(num_cpus::get().try_into().unwrap());
let rewrite_dir = rewriter_dir.unwrap_or(PathBuf::from("."));
let timeout = timeout.unwrap_or(30);
// Create log directory if necessary
if let Some(log_dir) = log_dir.as_ref() {
fs::create_dir_all(log_dir)?;
}
// Rebase circuit
if rebase {
rebase_nam(&circ)?;
}
// Logic to choose how to split the circuit
let taso_splits = |n_threads: NonZeroUsize| match n_threads.get() {
n if n >= 7 => (
vec![n, 3, 1],
Expand All @@ -102,7 +94,10 @@ fn nam_taso_optimise(
1 => (vec![1], vec![timeout]),
_ => unreachable!(),
};
// Load rewriter
// TODO: do not hardcode file name
let optimiser = PyDefaultTasoOptimiser::load_precompiled(rewrite_dir.join("nam_6_3.rwr"));
// Optimise
try_update_hugr(circ, |mut circ| {
let n_cx = circ
.commands()
Expand Down Expand Up @@ -135,7 +130,6 @@ fn nam_taso_optimise(
pub(crate) fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "passes")?;
m.add_function(wrap_pyfunction!(greedy_depth_reduce, m)?)?;
m.add_function(wrap_pyfunction!(nam_taso_optimise, m)?)?;
m.add_function(wrap_pyfunction!(taso_optimise, m)?)?;
m.add_class::<tket2::T2Op>()?;
m.add(
Expand Down

0 comments on commit daafc2c

Please sign in to comment.