Commit

Merge branch 'main' into feat/serialise-rewriter
lmondada authored Sep 27, 2023
2 parents d068c3b + d8fce77 commit 1792211
Showing 13 changed files with 787 additions and 162 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/with-bindings.yml
@@ -1,9 +1,10 @@
name: Run tests with TKET1 bindings

on:
push:
branches:
- main
# Disabled due to https://github.com/CQCL-DEV/tket2/issues/111
#push:
# branches:
# - main
workflow_dispatch: {}

env:
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -73,7 +73,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "19ed0fc" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "af664e3" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
36 changes: 36 additions & 0 deletions pyrs/src/circuit.rs
@@ -6,6 +6,7 @@ use pyo3::prelude::*;
use hugr::{Hugr, HugrView};
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
use tket_json_rs::circuit_json::SerialCircuit;

/// Apply a fallible function expecting a hugr on a pytket circuit.
@@ -52,3 +53,38 @@ pub fn to_hugr_dot(c: Py<PyAny>) -> PyResult<String> {
pub fn to_hugr(c: Py<PyAny>) -> PyResult<Hugr> {
with_hugr(c, |hugr| hugr)
}

#[pyfunction]
pub fn chunks(c: Py<PyAny>, max_chunk_size: usize) -> PyResult<CircuitChunks> {
with_hugr(c, |hugr| CircuitChunks::split(&hugr, max_chunk_size))
}

/// circuit module
pub fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;
m.add_class::<tket2::passes::CircuitChunks>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;
m.add_function(wrap_pyfunction!(chunks, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}
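
The new chunks binding splits a pytket circuit into pieces of bounded size, which can be inspected, replaced and stitched back together from Python. A minimal sketch of how it might be driven, mirroring the test added further down in this commit (the pyrs.pyrs module path is assumed from this repo's packaging; replacing a chunk with itself is a no-op stand-in for an optimised, equivalent subcircuit):

from pyrs.pyrs import circuit
from pytket.circuit import Circuit

# Build a small circuit and split it with max_chunk_size = 2.
c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3)
chunks = circuit.chunks(c, 2)

# Each chunk is exposed as an ordinary pytket Circuit.
subcircuits = chunks.circuits()

# Swap a chunk for a replacement, then reassemble the full circuit.
chunks.update_circuit(0, subcircuits[0])
reassembled = chunks.reassemble()
assert reassembled.depth() == c.depth()
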
26 changes: 1 addition & 25 deletions pyrs/src/lib.rs
@@ -1,6 +1,6 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::try_with_hugr;
use circuit::{add_circuit_module, try_with_hugr};
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;
@@ -25,30 +25,6 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
Ok(())
}

/// circuit module
fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}

/// portmatching module
fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "pattern")?;
13 changes: 12 additions & 1 deletion pyrs/test/test_bindings.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from pyrs.pyrs import passes
from pyrs.pyrs import passes, circuit
from pytket.circuit import Circuit


@@ -19,6 +19,17 @@ def test_depth_optimise():

assert c.depth() == 2

def test_chunks():
c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3)

assert c.depth() == 3

chunks = circuit.chunks(c, 2)
circuits = chunks.circuits()
chunks.update_circuit(0, circuits[0])
c2 = chunks.reassemble()

assert c2.depth() == 3

# from dataclasses import dataclass
# from typing import Callable, Iterable
4 changes: 3 additions & 1 deletion src/lib.rs
@@ -14,9 +14,11 @@ pub(crate) mod ops;
pub mod optimiser;
pub mod passes;
pub mod rewrite;
pub use ops::{symbolic_constant_op, Pauli, T2Op};

#[cfg(feature = "portmatching")]
pub mod portmatching;

mod utils;

pub use circuit::Circuit;
pub use ops::{symbolic_constant_op, Pauli, T2Op};
123 changes: 51 additions & 72 deletions src/optimiser/taso.rs
@@ -12,6 +12,7 @@
//! it gets too large.
mod eq_circ_class;
mod hugr_hash_set;
mod hugr_pchannel;
mod hugr_pqueue;
pub mod log;
@@ -24,11 +25,11 @@ pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

use fxhash::FxHashSet;
use hugr::Hugr;

use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel;
use crate::optimiser::taso::hugr_hash_set::HugrHashSet;
use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog};
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::rewrite::strategy::RewriteStrategy;
@@ -111,7 +112,7 @@ where
logger.log_best(best_circ_cost);

// Hash of seen circuits. Do not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]);
let mut seen_hashes = HugrHashSet::singleton(circ.circuit_hash(), best_circ_cost);

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
@@ -129,19 +130,26 @@

let rewrites = self.rewriter.get_rewrites(&circ);
for new_circ in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = (self.cost)(&new_circ);
if pq.len() > PRIORITY_QUEUE_CAPACITY / 2 && new_circ_cost > *pq.max_cost().unwrap()
{
// Ignore this circuit: it's too big
continue;
}
let new_circ_hash = new_circ.circuit_hash();
circ_cnt += 1;
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
if seen_hashes.contains(&new_circ_hash) {
if !seen_hashes.insert(new_circ_hash, new_circ_cost) {
// Ignore this circuit: we've already seen it
continue;
}
pq.push_with_hash_unchecked(new_circ, new_circ_hash);
seen_hashes.insert(new_circ_hash);
circ_cnt += 1;
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);
}

if pq.len() >= PRIORITY_QUEUE_CAPACITY {
// Haircut to keep the queue size manageable
pq.truncate(PRIORITY_QUEUE_CAPACITY / 2);
seen_hashes.clear_over(*pq.max_cost().unwrap());
}

if let Some(timeout) = timeout {
@@ -172,51 +180,37 @@ where
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

// multi-consumer priority channel for queuing circuits to be processed by the workers
let (tx_work, rx_work) =
let mut pq =
HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads);
// channel for sending circuits from threads back to main
let (tx_result, rx_result) = crossbeam_channel::unbounded();

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(&best_circ);
logger.log_best(best_circ_cost);

// Hash of seen circuits. Do not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]);

// Each worker waits for circuits to scan for rewrites using all the
// patterns and sends the results back to main.
let joins: Vec<_> = (0..n_threads)
.map(|i| {
TasoWorker::spawn(
rx_work.clone(),
tx_result.clone(),
pq.pop.clone().unwrap(),
pq.push.clone().unwrap(),
self.rewriter.clone(),
self.strategy.clone(),
Some(format!("taso-worker-{i}")),
)
})
.collect();
// Drop our copy of the worker channels, so we don't count as a
// connected worker.
drop(rx_work);
drop(tx_result);

// Queue the initial circuit
tx_work
pq.push
.as_ref()
.unwrap()
.send(vec![(initial_circ_hash, circ.clone())])
.unwrap();
// Drop our copy of the priority queue channels, so we don't count as a
// connected worker.
pq.drop_pop_push();

// A counter of circuits seen.
let mut circ_cnt = 1;

// A counter of jobs sent to the workers.
#[allow(unused)]
let mut jobs_sent = 0usize;
// A counter of completed jobs received from the workers.
#[allow(unused)]
let mut jobs_completed = 0usize;
// TODO: Report dropped jobs in the queue, so we can check for termination.

// Deadline for the optimisation timeout
@@ -225,66 +219,51 @@ where
Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)),
};

// Process worker results until we have seen all the circuits, or we run
// out of time.
// Main loop: log best circuits as they come in from the priority queue,
// until the timeout is reached.
let mut timeout_flag = false;
loop {
select! {
recv(rx_result) -> msg => {
recv(pq.log) -> msg => {
match msg {
Ok(hashed_circs) => {
let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| {
jobs_completed += 1;
for (circ_hash, circ) in &hashed_circs {
circ_cnt += 1;
logger.log_progress(circ_cnt, None, seen_hashes.len());
if seen_hashes.contains(circ_hash) {
continue;
}
seen_hashes.insert(*circ_hash);

let cost = (self.cost)(circ);

// Check if we got a new best circuit
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
}
jobs_sent += 1;
}
// Fill the workqueue with data from pq
tx_work.send(hashed_circs)
});
if send_result.is_err() {
eprintln!("All our workers panicked. Stopping optimisation.");
break;
}

// If there is no more data to process, we are done.
//
// TODO: Report dropped jobs in the workers, so we can check for termination.
//if jobs_sent == jobs_completed {
// break 'main;
//};
Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(best_circ_cost);
},
Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => {
logger.log_progress(circuit_cnt, None, seen_cnt);
}
Err(crossbeam_channel::RecvError) => {
eprintln!("All our workers panicked. Stopping optimisation.");
eprintln!("Priority queue panicked. Stopping optimisation.");
break;
}
}
}
recv(timeout_event) -> _ => {
timeout_flag = true;
pq.timeout();
break;
}
}
}

logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag);
// Empty the log from the priority queue and store final circuit count.
let mut circuit_cnt = None;
while let Ok(log) = pq.log.recv() {
match log {
PriorityChannelLog::NewBestCircuit(circ, cost) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(best_circ_cost);
}
PriorityChannelLog::CircuitCount(circ_cnt, _) => {
circuit_cnt = Some(circ_cnt);
}
}
}
logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag);

// Drop the channel so the threads know to stop.
drop(tx_work);
joins.into_iter().for_each(|j| j.join().unwrap());

best_circ
