From e33fc6a8f7820e52954de31991e9e589b7d5f8aa Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:16:00 +0200 Subject: [PATCH 1/3] perf: Reduce TASO hashtable size (#133) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- src/optimiser/taso.rs | 123 +++++++--------- src/optimiser/taso/hugr_hash_set.rs | 155 ++++++++++++++++++++ src/optimiser/taso/hugr_pchannel.rs | 218 +++++++++++++++++++++------- src/optimiser/taso/hugr_pqueue.rs | 19 ++- src/optimiser/taso/worker.rs | 2 +- 5 files changed, 387 insertions(+), 130 deletions(-) create mode 100644 src/optimiser/taso/hugr_hash_set.rs diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 2c7f4982..59843d96 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -12,6 +12,7 @@ //! it gets too large. mod eq_circ_class; +mod hugr_hash_set; mod hugr_pchannel; mod hugr_pqueue; pub mod log; @@ -24,11 +25,11 @@ pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; -use fxhash::FxHashSet; use hugr::Hugr; use crate::circuit::CircuitHash; -use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; +use crate::optimiser::taso::hugr_hash_set::HugrHashSet; +use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog}; use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ}; use crate::optimiser::taso::worker::TasoWorker; use crate::rewrite::strategy::RewriteStrategy; @@ -111,7 +112,7 @@ where logger.log_best(best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); + let mut seen_hashes = HugrHashSet::singleton(circ.circuit_hash(), best_circ_cost); // The priority queue of circuits to be processed (this should not get big) const PRIORITY_QUEUE_CAPACITY: usize = 10_000; @@ -129,19 +130,26 @@ where let rewrites = self.rewriter.get_rewrites(&circ); for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { + let new_circ_cost = (self.cost)(&new_circ); + if pq.len() > PRIORITY_QUEUE_CAPACITY / 2 && new_circ_cost > *pq.max_cost().unwrap() + { + // Ignore this circuit: it's too big + continue; + } let new_circ_hash = new_circ.circuit_hash(); - circ_cnt += 1; logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); - if seen_hashes.contains(&new_circ_hash) { + if !seen_hashes.insert(new_circ_hash, new_circ_cost) { + // Ignore this circuit: we've already seen it continue; } - pq.push_with_hash_unchecked(new_circ, new_circ_hash); - seen_hashes.insert(new_circ_hash); + circ_cnt += 1; + pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost); } if pq.len() >= PRIORITY_QUEUE_CAPACITY { // Haircut to keep the queue size manageable pq.truncate(PRIORITY_QUEUE_CAPACITY / 2); + seen_hashes.clear_over(*pq.max_cost().unwrap()); } if let Some(timeout) = timeout { @@ -172,51 +180,37 @@ where const PRIORITY_QUEUE_CAPACITY: usize = 10_000; // multi-consumer priority channel for queuing circuits to be processed by the workers - let (tx_work, rx_work) = + let mut pq = HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); - // channel for sending circuits from threads back to main - let (tx_result, rx_result) = crossbeam_channel::unbounded(); let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); 
let mut best_circ_cost = (self.cost)(&best_circ); - logger.log_best(best_circ_cost); - - // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); // Each worker waits for circuits to scan for rewrites using all the // patterns and sends the results back to main. let joins: Vec<_> = (0..n_threads) .map(|i| { TasoWorker::spawn( - rx_work.clone(), - tx_result.clone(), + pq.pop.clone().unwrap(), + pq.push.clone().unwrap(), self.rewriter.clone(), self.strategy.clone(), Some(format!("taso-worker-{i}")), ) }) .collect(); - // Drop our copy of the worker channels, so we don't count as a - // connected worker. - drop(rx_work); - drop(tx_result); // Queue the initial circuit - tx_work + pq.push + .as_ref() + .unwrap() .send(vec![(initial_circ_hash, circ.clone())]) .unwrap(); + // Drop our copy of the priority queue channels, so we don't count as a + // connected worker. + pq.drop_pop_push(); - // A counter of circuits seen. - let mut circ_cnt = 1; - - // A counter of jobs sent to the workers. - #[allow(unused)] - let mut jobs_sent = 0usize; - // A counter of completed jobs received from the workers. - #[allow(unused)] - let mut jobs_completed = 0usize; // TODO: Report dropped jobs in the queue, so we can check for termination. // Deadline for the optimisation timeout @@ -225,66 +219,51 @@ where Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), }; - // Process worker results until we have seen all the circuits, or we run - // out of time. + // Main loop: log best circuits as they come in from the priority queue, + // until the timeout is reached. let mut timeout_flag = false; loop { select! { - recv(rx_result) -> msg => { + recv(pq.log) -> msg => { match msg { - Ok(hashed_circs) => { - let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| { - jobs_completed += 1; - for (circ_hash, circ) in &hashed_circs { - circ_cnt += 1; - logger.log_progress(circ_cnt, None, seen_hashes.len()); - if seen_hashes.contains(circ_hash) { - continue; - } - seen_hashes.insert(*circ_hash); - - let cost = (self.cost)(circ); - - // Check if we got a new best circuit - if cost < best_circ_cost { - best_circ = circ.clone(); - best_circ_cost = cost; - logger.log_best(best_circ_cost); - } - jobs_sent += 1; - } - // Fill the workqueue with data from pq - tx_work.send(hashed_circs) - }); - if send_result.is_err() { - eprintln!("All our workers panicked. Stopping optimisation."); - break; - } - - // If there is no more data to process, we are done. - // - // TODO: Report dropped jobs in the workers, so we can check for termination. - //if jobs_sent == jobs_completed { - // break 'main; - //}; + Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => { + best_circ = circ; + best_circ_cost = cost; + logger.log_best(best_circ_cost); }, + Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => { + logger.log_progress(circuit_cnt, None, seen_cnt); + } Err(crossbeam_channel::RecvError) => { - eprintln!("All our workers panicked. Stopping optimisation."); + eprintln!("Priority queue panicked. Stopping optimisation."); break; } } } recv(timeout_event) -> _ => { timeout_flag = true; + pq.timeout(); break; } } } - logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag); + // Empty the log from the priority queue and store final circuit count. 
+ let mut circuit_cnt = None; + while let Ok(log) = pq.log.recv() { + match log { + PriorityChannelLog::NewBestCircuit(circ, cost) => { + best_circ = circ; + best_circ_cost = cost; + logger.log_best(best_circ_cost); + } + PriorityChannelLog::CircuitCount(circ_cnt, _) => { + circuit_cnt = Some(circ_cnt); + } + } + } + logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag); - // Drop the channel so the threads know to stop. - drop(tx_work); joins.into_iter().for_each(|j| j.join().unwrap()); best_circ diff --git a/src/optimiser/taso/hugr_hash_set.rs b/src/optimiser/taso/hugr_hash_set.rs new file mode 100644 index 00000000..e710f32c --- /dev/null +++ b/src/optimiser/taso/hugr_hash_set.rs @@ -0,0 +1,155 @@ +use std::collections::VecDeque; + +use fxhash::FxHashSet; + +/// A datastructure storing Hugr hashes. +/// +/// Stores hashes in buckets based on a cost function, to allow clearing +/// the set of hashes that are no longer needed. +pub(super) struct HugrHashSet { + buckets: VecDeque>, + /// The cost at the front of the queue. + min_cost: Option, +} + +impl HugrHashSet { + /// Create a new empty set. + pub(super) fn new() -> Self { + Self { + buckets: VecDeque::new(), + min_cost: None, + } + } + + /// Create a new set with a single hash and cost. + pub(super) fn singleton(hash: u64, cost: usize) -> Self { + let mut set = Self::new(); + set.insert(hash, cost); + set + } + + /// Insert circuit with given hash and cost. + /// + /// Returns whether the insertion was successful, i.e. the negation + /// of whether it was already present. + pub(super) fn insert(&mut self, hash: u64, cost: usize) -> bool { + let Some(min_cost) = self.min_cost.as_mut() else { + self.min_cost = Some(cost); + self.buckets.push_front([hash].into_iter().collect()); + return true; + }; + self.buckets.reserve(min_cost.saturating_sub(cost)); + while cost < *min_cost { + self.buckets.push_front(FxHashSet::default()); + *min_cost -= 1; + } + let bucket_index = cost - *min_cost; + if bucket_index >= self.buckets.len() { + self.buckets + .resize_with(bucket_index + 1, FxHashSet::default); + } + self.buckets[bucket_index].insert(hash) + } + + /// Returns whether the given hash is present in the set. + #[allow(dead_code)] + pub(super) fn contains(&self, hash: u64, cost: usize) -> bool { + let Some(min_cost) = self.min_cost else { + return false; + }; + let Some(index) = cost.checked_sub(min_cost) else { + return false; + }; + let Some(b) = self.buckets.get(index) else { + return false; + }; + b.contains(&hash) + } + + fn max_cost(&self) -> Option { + Some(self.min_cost? + self.buckets.len() - 1) + } + + /// Remove all hashes with cost strictly greater than the given cost. 
+ pub(super) fn clear_over(&mut self, cost: usize) { + while self.max_cost().is_some() && self.max_cost() > Some(cost) { + self.buckets.pop_back(); + if self.buckets.is_empty() { + self.min_cost = None; + } + } + } + + /// The number of hashes in the set + pub(super) fn len(&self) -> usize { + self.buckets.iter().map(|b| b.len()).sum() + } +} + +#[cfg(test)] +mod tests { + use super::HugrHashSet; + + #[test] + fn insert_elements() { + // For simplicity, we use as cost: hash % 10 + let mut set = HugrHashSet::new(); + + assert!(!set.contains(0, 0)); + assert!(!set.contains(2, 0)); + assert!(!set.contains(2, 3)); + + assert!(set.insert(20, 2)); + assert!(!set.contains(0, 0)); + assert!(!set.insert(20, 2)); + assert!(set.insert(22, 2)); + assert!(set.insert(23, 2)); + + assert!(set.contains(22, 2)); + + assert!(set.insert(33, 3)); + assert_eq!(set.min_cost, Some(2)); + assert_eq!(set.max_cost(), Some(3)); + assert_eq!( + set.buckets, + [ + [20, 22, 23].into_iter().collect(), + [33].into_iter().collect() + ] + ); + + assert!(set.insert(3, 0)); + assert!(set.insert(1, 0)); + assert!(!set.insert(22, 2)); + assert!(set.contains(33, 3)); + assert!(set.contains(3, 0)); + assert!(!set.contains(3, 2)); + assert_eq!(set.min_cost, Some(0)); + assert_eq!(set.max_cost(), Some(3)); + + assert_eq!(set.min_cost, Some(0)); + assert_eq!( + set.buckets, + [ + [1, 3].into_iter().collect(), + [].into_iter().collect(), + [20, 22, 23].into_iter().collect(), + [33].into_iter().collect(), + ] + ); + } + + #[test] + fn remove_empty() { + let mut set = HugrHashSet::new(); + assert!(set.insert(20, 2)); + assert!(set.insert(30, 3)); + + assert_eq!(set.len(), 2); + set.clear_over(2); + assert_eq!(set.len(), 1); + set.clear_over(0); + assert_eq!(set.len(), 0); + assert!(set.min_cost.is_none()); + } +} diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs index 1ec0d2e4..b6b295cb 100644 --- a/src/optimiser/taso/hugr_pchannel.rs +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -1,95 +1,213 @@ //! A multi-producer multi-consumer min-priority channel of Hugrs. -use std::marker::PhantomData; use std::thread; use crossbeam_channel::{select, Receiver, Sender}; use hugr::Hugr; -use super::hugr_pqueue::{Entry, HugrPQ}; +use super::{ + hugr_hash_set::HugrHashSet, + hugr_pqueue::{Entry, HugrPQ}, +}; /// A priority channel for HUGRs. /// /// Queues hugrs using a cost function `C` that produces priority values `P`. /// /// Uses a thread internally to orchestrate the queueing. -pub struct HugrPriorityChannel { - _phantom: PhantomData<(P, C)>, +pub(super) struct HugrPriorityChannel { + // Channels to add and remove circuits from the queue. + push: Receiver>, + pop: Sender<(u64, Hugr)>, + // Outbound channel to log to main thread. + log: Sender, + // Inbound channel to be terminated. + timeout: Receiver<()>, + // The queue capacity. Queue size is halved when it exceeds this. + queue_capacity: usize, + // The priority queue data structure. + pq: HugrPQ, + // The set of hashes we've seen. + seen_hashes: HugrHashSet, + // The minimum cost we've seen. + min_cost: Option, + // The number of circuits we've seen (for logging). + circ_cnt: usize, } -pub type Item = (u64, Hugr); +pub(super) type Item = (u64, Hugr); -impl HugrPriorityChannel +/// Logging information from the priority channel. +pub(super) enum PriorityChannelLog { + NewBestCircuit(Hugr, usize), + CircuitCount(usize, usize), +} + +/// Channels for communication with the priority channel. 
+pub(super) struct PriorityChannelCommunication { + pub(super) push: Option>>, + pub(super) pop: Option>, + pub(super) log: Receiver, + timeout: Sender<()>, +} + +impl PriorityChannelCommunication { + /// Send Timeout signal to the priority channel. + pub(super) fn timeout(&self) { + self.timeout.send(()).unwrap(); + } + + /// Close the local copies of the push and pop channels. + pub(super) fn drop_pop_push(&mut self) { + self.pop = None; + self.push = None; + } +} + +impl HugrPriorityChannel where - C: Fn(&Hugr) -> P + Send + Sync + 'static, + C: Fn(&Hugr) -> usize + Send + Sync + 'static, { /// Initialize the queueing system. /// - /// Get back a channel on which to queue hugrs with their hash, and - /// a channel on which to receive the output. - pub fn init(cost_fn: C, queue_capacity: usize) -> (Sender>, Receiver) { - let (ins, inr) = crossbeam_channel::unbounded(); - let (outs, outr) = crossbeam_channel::bounded(0); - Self::run(inr, outs, cost_fn, queue_capacity); - (ins, outr) + /// Start the Hugr priority queue in a new thread. + /// + /// Get back channels for communication with the priority queue + /// - push/pop channels for adding and removing circuits to/from the queue, + /// - a channel on which to receive logging information, and + /// - a channel on which to send a timeout signal. + pub(super) fn init(cost_fn: C, queue_capacity: usize) -> PriorityChannelCommunication { + // channels for pushing and popping circuits from pqueue + let (tx_push, rx_push) = crossbeam_channel::unbounded(); + let (tx_pop, rx_pop) = crossbeam_channel::bounded(0); + // channels for communication with main (logging, minimum circuits and timeout) + let (tx_log, rx_log) = crossbeam_channel::unbounded(); + let (tx_timeout, rx_timeout) = crossbeam_channel::bounded(0); + let pq = + HugrPriorityChannel::new(rx_push, tx_pop, tx_log, rx_timeout, cost_fn, queue_capacity); + pq.run(); + PriorityChannelCommunication { + push: Some(tx_push), + pop: Some(rx_pop), + log: rx_log, + timeout: tx_timeout, + } } - /// Run the queuer as a thread. - fn run( - in_channel_orig: Receiver>, - out_channel_orig: Sender<(u64, Hugr)>, + fn new( + push: Receiver>, + pop: Sender<(u64, Hugr)>, + log: Sender, + timeout: Receiver<()>, cost_fn: C, queue_capacity: usize, - ) { + ) -> Self { + // The priority queue, local to this thread. + let pq: HugrPQ = HugrPQ::with_capacity(cost_fn, queue_capacity); + // The set of hashes we've seen. + let seen_hashes = HugrHashSet::new(); + // The minimum cost we've seen. + let min_cost = None; + // The number of circuits we've seen (for logging). + let circ_cnt = 0; + + HugrPriorityChannel { + push, + pop, + log, + timeout, + queue_capacity, + pq, + seen_hashes, + min_cost, + circ_cnt, + } + } + + /// Run the queuer as a thread. + fn run(mut self) { let builder = thread::Builder::new().name("priority queueing".into()); - let in_channel = in_channel_orig.clone(); - let out_channel = out_channel_orig.clone(); let _ = builder .name("priority-channel".into()) .spawn(move || { - // The priority queue, local to this thread. - let mut pq: HugrPQ = - HugrPQ::with_capacity(cost_fn, queue_capacity); - loop { - if pq.is_empty() { - // Nothing queued to go out. Wait for input. - match in_channel.recv() { - Ok(new_circs) => { - for (hash, circ) in new_circs { - pq.push_with_hash_unchecked(circ, hash); - } - } - // The sender has closed the channel, we can stop. - Err(_) => break, - } + if self.pq.is_empty() { + let Ok(new_circs) = self.push.recv() else { + // The senders have closed the channel, we can stop. 
+ break; + }; + self.recv(new_circs); } else { select! { - recv(in_channel) -> result => { - match result { - Ok(new_circs) => { - for (hash, circ) in new_circs { - pq.push_with_hash_unchecked(circ, hash); - } - } - // The sender has closed the channel, we can stop. - Err(_) => break, - } + recv(self.push) -> result => { + let Ok(new_circs) = result else { + // The senders have closed the channel, we can stop. + break; + }; + self.recv(new_circs); } - send(out_channel, {let Entry {hash, circ, ..} = pq.pop().unwrap(); (hash, circ)}) -> result => { + send(self.pop, {let Entry {hash, circ, ..} = self.pq.pop().unwrap(); (hash, circ)}) -> result => { match result { Ok(()) => {}, // The receivers have closed the channel, we can stop. Err(_) => break, } } + recv(self.timeout) -> _ => { + // We've timed out. + break + } } } - if pq.len() >= queue_capacity { - pq.truncate(queue_capacity / 2); - } } + // Send a last set of logs before terminating. + self.log + .send(PriorityChannelLog::CircuitCount( + self.circ_cnt, + self.seen_hashes.len(), + )) + .unwrap(); }) .unwrap(); } + + /// Add circuits to queue. + fn recv(&mut self, circs: Vec<(u64, Hugr)>) { + for (hash, circ) in circs { + let cost = (self.pq.cost_fn)(&circ); + if (self.pq.len() > self.queue_capacity / 2 && cost > *self.pq.max_cost().unwrap()) + || !self.seen_hashes.insert(hash, cost) + { + // Ignore this circuit: it's either too big or we've seen it before. + continue; + } + + // A new best circuit + if self.min_cost.is_none() || Some(cost) < self.min_cost { + self.min_cost = Some(cost); + self.log + .send(PriorityChannelLog::NewBestCircuit(circ.clone(), cost)) + .unwrap(); + } + + self.circ_cnt += 1; + self.pq.push_unchecked(circ, hash, cost); + + // Send logs every 1000 circuits. + if self.circ_cnt % 1000 == 0 { + // TODO: Add a minimum time between logs + self.log + .send(PriorityChannelLog::CircuitCount( + self.circ_cnt, + self.seen_hashes.len(), + )) + .unwrap(); + } + } + // If the queue got too big, truncate it. + if self.pq.len() >= self.queue_capacity { + self.pq.truncate(self.queue_capacity / 2); + self.seen_hashes.clear_over(*self.pq.max_cost().unwrap()); + } + } } diff --git a/src/optimiser/taso/hugr_pqueue.rs b/src/optimiser/taso/hugr_pqueue.rs index eba000e0..7dd532cc 100644 --- a/src/optimiser/taso/hugr_pqueue.rs +++ b/src/optimiser/taso/hugr_pqueue.rs @@ -13,7 +13,7 @@ use crate::circuit::CircuitHash; pub(super) struct HugrPQ { queue: DoublePriorityQueue, hash_lookup: FxHashMap, - cost_fn: C, + pub(super) cost_fn: C, } pub(super) struct Entry { @@ -51,20 +51,20 @@ impl HugrPQ { C: Fn(&Hugr) -> P, { let hash = hugr.circuit_hash(); - self.push_with_hash_unchecked(hugr, hash); + let cost = (self.cost_fn)(&hugr); + self.push_unchecked(hugr, hash, cost); } - /// Push a Hugr into the queue with a precomputed hash. + /// Push a Hugr into the queue with a precomputed hash and cost. /// - /// This is useful to avoid recomputing the hash in [`HugrPQ::push`] when - /// it is already known. + /// This is useful to avoid recomputing the hash and cost function in + /// [`HugrPQ::push`] when they are already known. /// /// This does not check that the hash is valid. - pub(super) fn push_with_hash_unchecked(&mut self, hugr: Hugr, hash: u64) + pub(super) fn push_unchecked(&mut self, hugr: Hugr, hash: u64, cost: P) where C: Fn(&Hugr) -> P, { - let cost = (self.cost_fn)(&hugr); self.queue.push(hash, cost); self.hash_lookup.insert(hash, hugr); } @@ -85,6 +85,11 @@ impl HugrPQ { } } + /// The largest cost in the queue. 
+ pub(super) fn max_cost(&self) -> Option<&P> { + self.queue.peek_max().map(|(_, cost)| cost) + } + delegate! { to self.queue { pub(super) fn len(&self) -> usize; diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs index 0d8667c8..a2aadbf7 100644 --- a/src/optimiser/taso/worker.rs +++ b/src/optimiser/taso/worker.rs @@ -51,7 +51,7 @@ where let send = tracing::trace_span!(target: "taso::metrics", "TasoWorker::send_result") .in_scope(|| tx_result.send(hashed_circs)); if send.is_err() { - // The main thread closed the send channel, we can stop. + // The priority queue closed the send channel, we can stop. break; } } From 31ecafa39edd4cae0f12602a8712f6813d172e7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:48:49 +0200 Subject: [PATCH 2/3] ci: Disable failing tkcxx test workflows (#139) The tests that compile `tket-rs` are currently failing in CI with a linking error. This started failing on [13/09/2023](https://github.com/CQCL-DEV/tket2/actions/runs/6173273334), but running the workflow on older revisions also fails. I couldn't reproduce them locally so I'll just disable them temporarily. We may want to try some nix-based workflow for this, to ensure upstream updates don't break our builds... --- .github/workflows/with-bindings.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/with-bindings.yml b/.github/workflows/with-bindings.yml index b1f72887..9ec6d549 100644 --- a/.github/workflows/with-bindings.yml +++ b/.github/workflows/with-bindings.yml @@ -1,9 +1,10 @@ name: Run tests with TKET1 bindings on: - push: - branches: - - main + # Disabled due to https://github.com/CQCL-DEV/tket2/issues/111 + #push: + # branches: + # - main workflow_dispatch: {} env: From d8fce779ad04e6ec22d7714e2c02926ca9a552a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:38:10 +0200 Subject: [PATCH 3/3] feat: Split & reassemble circuit chunks (#130) This is a pass utility to split circuits into chunks that can be independently optimized. Closes #129 --- Cargo.toml | 2 +- pyrs/src/circuit.rs | 36 ++++ pyrs/src/lib.rs | 26 +-- pyrs/test/test_bindings.py | 13 +- src/lib.rs | 4 +- src/passes.rs | 5 +- src/passes/chunks.rs | 339 +++++++++++++++++++++++++++++++++++++ 7 files changed, 396 insertions(+), 29 deletions(-) create mode 100644 src/passes/chunks.rs diff --git a/Cargo.toml b/Cargo.toml index f7a52f64..4e19f896 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "19ed0fc" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "af664e3" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/pyrs/src/circuit.rs b/pyrs/src/circuit.rs index 850278ff..dbb6db97 100644 --- a/pyrs/src/circuit.rs +++ b/pyrs/src/circuit.rs @@ -6,6 +6,7 @@ use pyo3::prelude::*; use hugr::{Hugr, HugrView}; use tket2::extension::REGISTRY; use tket2::json::TKETDecode; +use tket2::passes::CircuitChunks; use tket_json_rs::circuit_json::SerialCircuit; /// Apply a fallible function expecting a hugr on a pytket circuit. 
@@ -52,3 +53,38 @@ pub fn to_hugr_dot(c: Py) -> PyResult { pub fn to_hugr(c: Py) -> PyResult { with_hugr(c, |hugr| hugr) } + +#[pyfunction] +pub fn chunks(c: Py, max_chunk_size: usize) -> PyResult { + with_hugr(c, |hugr| CircuitChunks::split(&hugr, max_chunk_size)) +} + +/// circuit module +pub fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "circuit")?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + m.add_function(wrap_pyfunction!(validate_hugr, m)?)?; + m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?; + m.add_function(wrap_pyfunction!(to_hugr, m)?)?; + m.add_function(wrap_pyfunction!(chunks, m)?)?; + + m.add("HugrError", py.get_type::())?; + m.add("BuildError", py.get_type::())?; + m.add( + "ValidationError", + py.get_type::(), + )?; + m.add( + "HUGRSerializationError", + py.get_type::(), + )?; + m.add( + "OpConvertError", + py.get_type::(), + )?; + + parent.add_submodule(m) +} diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index 146c36f7..062b4ba4 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -1,6 +1,6 @@ //! Python bindings for TKET2. #![warn(missing_docs)] -use circuit::try_with_hugr; +use circuit::{add_circuit_module, try_with_hugr}; use pyo3::prelude::*; use tket2::{json::TKETDecode, passes::apply_greedy_commutation}; use tket_json_rs::circuit_json::SerialCircuit; @@ -25,30 +25,6 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> { Ok(()) } -/// circuit module -fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> { - let m = PyModule::new(py, "circuit")?; - m.add_class::()?; - m.add_class::()?; - - m.add("HugrError", py.get_type::())?; - m.add("BuildError", py.get_type::())?; - m.add( - "ValidationError", - py.get_type::(), - )?; - m.add( - "HUGRSerializationError", - py.get_type::(), - )?; - m.add( - "OpConvertError", - py.get_type::(), - )?; - - parent.add_submodule(m) -} - /// portmatching module fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> { let m = PyModule::new(py, "pattern")?; diff --git a/pyrs/test/test_bindings.py b/pyrs/test/test_bindings.py index 211b0e51..a43263e5 100644 --- a/pyrs/test/test_bindings.py +++ b/pyrs/test/test_bindings.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from pyrs.pyrs import passes +from pyrs.pyrs import passes, circuit from pytket.circuit import Circuit @@ -19,6 +19,17 @@ def test_depth_optimise(): assert c.depth() == 2 +def test_chunks(): + c = Circuit(4).CX(0, 2).CX(1, 3).CX(1, 2).CX(0, 3).CX(1, 3) + + assert c.depth() == 3 + + chunks = circuit.chunks(c, 2) + circuits = chunks.circuits() + chunks.update_circuit(0, circuits[0]) + c2 = chunks.reassemble() + + assert c2.depth() == 3 # from dataclasses import dataclass # from typing import Callable, Iterable diff --git a/src/lib.rs b/src/lib.rs index 3e89fda1..af4ec438 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,9 +14,11 @@ pub(crate) mod ops; pub mod optimiser; pub mod passes; pub mod rewrite; -pub use ops::{symbolic_constant_op, Pauli, T2Op}; #[cfg(feature = "portmatching")] pub mod portmatching; mod utils; + +pub use circuit::Circuit; +pub use ops::{symbolic_constant_op, Pauli, T2Op}; diff --git a/src/passes.rs b/src/passes.rs index 68ac2a44..efacfec5 100644 --- a/src/passes.rs +++ b/src/passes.rs @@ -1,6 +1,9 @@ -//! Optimisation passes for circuits. +//! Optimisation passes and related utilities for circuits. 
mod commutation; pub use commutation::apply_greedy_commutation; #[cfg(feature = "pyo3")] pub use commutation::PyPullForwardError; + +pub mod chunks; +pub use chunks::CircuitChunks; diff --git a/src/passes/chunks.rs b/src/passes/chunks.rs new file mode 100644 index 00000000..ff10d3b7 --- /dev/null +++ b/src/passes/chunks.rs @@ -0,0 +1,339 @@ +//! Utility + +use std::collections::HashMap; + +use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder}; +use hugr::extension::ExtensionSet; +use hugr::hugr::hugrmut::HugrMut; +use hugr::hugr::views::sibling_subgraph::ConvexChecker; +use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph}; +use hugr::hugr::{HugrError, NodeMetadata}; +use hugr::ops::handle::DataflowParentID; +use hugr::types::{FunctionType, Signature}; +use hugr::{Hugr, HugrView, Node, Port, Wire}; +use itertools::Itertools; + +use crate::extension::REGISTRY; +use crate::Circuit; + +#[cfg(feature = "pyo3")] +use crate::json::TKETDecode; +#[cfg(feature = "pyo3")] +use pyo3::{exceptions::PyAttributeError, pyclass, pymethods, Py, PyAny, PyResult}; +#[cfg(feature = "pyo3")] +use tket_json_rs::circuit_json::SerialCircuit; + +/// An identifier for the connection between chunks. +/// +/// This is based on the wires of the original circuit. +/// +/// When reassembling the circuit, the input/output wires of each chunk are +/// re-linked by matching these identifiers. +pub type ChunkConnection = Wire; + +/// A chunk of a circuit. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "pyo3", pyclass)] +pub struct Chunk { + /// The extracted circuit. + pub circ: Hugr, + /// The original wires connected to the input. + pub inputs: Vec, + /// The original wires connected to the output. + pub outputs: Vec, +} + +impl Chunk { + /// Extract a chunk from a circuit. + /// + /// The chunk is extracted from the input wires to the output wires. + pub(self) fn extract<'h, H: HugrView>( + circ: &'h H, + nodes: impl IntoIterator, + checker: &mut ConvexChecker<'h, H>, + ) -> Self { + let subgraph = SiblingSubgraph::try_from_nodes_with_checker( + nodes.into_iter().collect_vec(), + circ, + checker, + ) + .expect("Failed to define the chunk subgraph"); + let extracted = subgraph + .extract_subgraph(circ, "Chunk", ExtensionSet::new()) + .expect("Failed to extract chunk"); + // Transform the subgraph's input/output sets into wires that can be + // matched between different chunks. + // + // This requires finding the `Outgoing` port corresponding to each + // subgraph input. + let inputs = subgraph + .incoming_ports() + .iter() + .map(|wires| { + let (inp_node, inp_port) = wires[0]; + let (out_node, out_port) = circ + .linked_ports(inp_node, inp_port) + .exactly_one() + .ok() + .unwrap(); + Wire::new(out_node, out_port) + }) + .collect(); + let outputs = subgraph + .outgoing_ports() + .iter() + .map(|&(node, port)| Wire::new(node, port)) + .collect(); + Self { + circ: extracted, + inputs, + outputs, + } + } + + /// Insert the chunk back into a circuit. + // + // TODO: The new chunk may have input ports directly connected to outputs. We have to take care of those. 
+ #[allow(clippy::type_complexity)] + pub(self) fn insert(&self, circ: &mut impl HugrMut, root: Node) -> ChunkInsertResult { + let chunk_sg: SiblingGraph<'_, DataflowParentID> = + SiblingGraph::try_new(&self.circ, self.circ.root()).unwrap(); + let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg) + .expect("The chunk circuit is no longer a dataflow"); + let node_map = circ + .insert_subgraph(root, &self.circ, &subgraph) + .expect("Failed to insert the chunk subgraph") + .node_map; + + let [inp, out] = circ.get_io(root).unwrap(); + let mut input_map = HashMap::with_capacity(self.inputs.len()); + let mut output_map = HashMap::with_capacity(self.outputs.len()); + + for (&connection, incoming) in self.inputs.iter().zip(subgraph.incoming_ports().iter()) { + let incoming = incoming.iter().map(|&(node, port)| { + if node == out { + // TODO: Add a map for directly connected Input connection -> Output Wire. + panic!("Chunk input directly connected to the output. This is not currently supported."); + } + (*node_map.get(&node).unwrap(),port) + }).collect_vec(); + input_map.insert(connection, incoming); + } + + for (&wire, &(node, port)) in self.outputs.iter().zip(subgraph.outgoing_ports().iter()) { + if node == inp { + // TODO: Add a map for directly connected Input Wire -> Output Wire. + panic!("Chunk input directly connected to the output. This is not currently supported."); + } + output_map.insert(wire, (*node_map.get(&node).unwrap(), port)); + } + + ChunkInsertResult { + incoming_connections: input_map, + outgoing_connections: output_map, + } + } +} + +/// A map from the original input/output [`ChunkConnection`]s to an inserted chunk's inputs and outputs. +struct ChunkInsertResult { + /// A map from incoming connections to a chunk, to the new node and incoming port targets. + /// + /// A chunk may specify multiple targets to be connected to a single incoming `ChunkConnection`. + pub incoming_connections: HashMap>, + /// A map from outgoing connections from a chunk, to the new node and outgoing port target. + pub outgoing_connections: HashMap, +} + +/// An utility for splitting a circuit into chunks, and reassembling them afterwards. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "pyo3", pyclass)] +pub struct CircuitChunks { + /// The original circuit's signature. + signature: FunctionType, + + /// The original circuit's root metadata. + root_meta: NodeMetadata, + + /// The original circuit's inputs. + input_connections: Vec, + + /// The original circuit's outputs. + output_connections: Vec, + + /// The split circuits. + pub chunks: Vec, +} + +impl CircuitChunks { + /// Split a circuit into chunks. + /// + /// The circuit is split into chunks of at most `max_size` gates. 
+ pub fn split(circ: &impl Circuit, max_size: usize) -> Self { + let root_meta = circ.get_metadata(circ.root()).clone(); + let signature = circ.circuit_signature().clone(); + + let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap(); + let input_connections = circ + .node_outputs(circ_input) + .map(|port| Wire::new(circ_input, port)) + .collect(); + let output_connections = circ + .node_inputs(circ_output) + .flat_map(|p| circ.linked_ports(circ_output, p)) + .map(|(n, p)| Wire::new(n, p)) + .collect(); + + let mut chunks = Vec::new(); + let mut convex_checker = ConvexChecker::new(circ); + for commands in &circ.commands().map(|cmd| cmd.node()).chunks(max_size) { + chunks.push(Chunk::extract(circ, commands, &mut convex_checker)); + } + Self { + signature, + root_meta, + input_connections, + output_connections, + chunks, + } + } + + /// Reassemble the chunks into a circuit. + pub fn reassemble(self) -> Result { + let name = self + .root_meta + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let signature = Signature { + signature: self.signature, + // TODO: Is this correct? Can a circuit root have a fixed set of input extensions? + input_extensions: ExtensionSet::new(), + }; + + let builder = FunctionBuilder::new(name, signature).unwrap(); + let inputs = builder.input_wires(); + // TODO: Use the correct REGISTRY if the method accepts custom input resources. + let mut reassembled = builder.finish_hugr_with_outputs(inputs, ®ISTRY).unwrap(); + let root = reassembled.root(); + let [reassembled_input, reassembled_output] = reassembled.get_io(root).unwrap(); + + // The chunks input and outputs are each identified with a + // [`ChunkConnection`]. We collect both sides first, and rewire them + // after the chunks have been inserted. + let mut sources: HashMap = HashMap::new(); + let mut targets: HashMap> = HashMap::new(); + + for (&connection, port) in self + .input_connections + .iter() + .zip(reassembled.node_outputs(reassembled_input)) + { + reassembled.disconnect(reassembled_input, port)?; + sources.insert(connection, (reassembled_input, port)); + } + for (&connection, port) in self + .output_connections + .iter() + .zip(reassembled.node_inputs(reassembled_output)) + { + targets.insert(connection, vec![(reassembled_output, port)]); + } + + for chunk in self.chunks { + // Insert the chunk circuit without its input/output nodes. + let ChunkInsertResult { + incoming_connections, + outgoing_connections, + } = chunk.insert(&mut reassembled, root); + // Reconnect the chunk's inputs and outputs in the reassembled circuit. + sources.extend(outgoing_connections); + incoming_connections.into_iter().for_each(|(wire, tgts)| { + targets.entry(wire).or_default().extend(tgts); + }); + } + + // Reconnect the different chunks. + for (connection, (source, source_port)) in sources { + let Some(tgts) = targets.remove(&connection) else { + continue; + }; + for (target, target_port) in tgts { + reassembled.connect(source, source_port, target, target_port)?; + } + } + + Ok(reassembled) + } + + /// Returns a list of references to the split circuits. + pub fn circuits(&self) -> impl Iterator { + self.chunks.iter().map(|chunk| &chunk.circ) + } +} + +#[cfg(feature = "pyo3")] +#[pymethods] +impl CircuitChunks { + /// Reassemble the chunks into a circuit. + #[pyo3(name = "reassemble")] + fn py_reassemble(&self) -> PyResult> { + let hugr = self.clone().reassemble()?; + SerialCircuit::encode(&hugr)?.to_tket1() + } + + /// Returns clones of the split circuits. 
+ #[pyo3(name = "circuits")] + fn py_circuits(&self) -> PyResult>> { + self.circuits() + .map(|hugr| SerialCircuit::encode(hugr)?.to_tket1()) + .collect() + } + + /// Replaces a chunk's circuit with an updated version. + #[pyo3(name = "update_circuit")] + fn py_update_circuit(&mut self, index: usize, new_circ: Py) -> PyResult<()> { + let hugr = SerialCircuit::_from_tket1(new_circ).decode()?; + if hugr.circuit_signature() != self.chunks[index].circ.circuit_signature() { + return Err(PyAttributeError::new_err( + "The new circuit has a different signature.", + )); + } + self.chunks[index].circ = hugr; + Ok(()) + } +} + +#[cfg(test)] +mod test { + use crate::circuit::CircuitHash; + use crate::utils::build_simple_circuit; + use crate::T2Op; + + use super::*; + + #[test] + fn split_reassemble() { + let circ = build_simple_circuit(2, |circ| { + circ.append(T2Op::H, [0])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::T, [1])?; + circ.append(T2Op::H, [0])?; + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::H, [0])?; + circ.append(T2Op::CX, [0, 1])?; + Ok(()) + }) + .unwrap(); + + let mut chunks = CircuitChunks::split(&circ, 3); + + // Rearrange the chunks so nodes are inserted in a new order. + chunks.chunks.reverse(); + + let mut reassembled = chunks.reassemble().unwrap(); + + reassembled.infer_and_validate(®ISTRY).unwrap(); + assert_eq!(circ.circuit_hash(), reassembled.circuit_hash()); + } +}
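---

A minimal usage sketch of the chunking pass introduced in PATCH 3/3, adapted from the `split_reassemble` test above. It uses the crate-internal `build_simple_circuit` test helper from that test; the "optimised" chunk replacement is a stand-in (here just a clone) for the result of an independent per-chunk optimisation pass, which this patch does not prescribe.

    // Sketch only: split, replace one chunk, reassemble, and validate.
    use crate::circuit::CircuitHash;
    use crate::extension::REGISTRY;
    use crate::passes::CircuitChunks;
    use crate::utils::build_simple_circuit;
    use crate::T2Op;

    fn chunked_optimise_roundtrip() {
        // Build a small 2-qubit circuit, as in the test above.
        let circ = build_simple_circuit(2, |circ| {
            circ.append(T2Op::H, [0])?;
            circ.append(T2Op::CX, [0, 1])?;
            circ.append(T2Op::T, [1])?;
            circ.append(T2Op::CX, [0, 1])?;
            Ok(())
        })
        .unwrap();

        // Split into chunks of at most 2 gates each.
        let mut chunks = CircuitChunks::split(&circ, 2);

        // Each chunk is a standalone Hugr that can be optimised independently
        // (e.g. handed to the TASO optimiser). Writing a result back is a plain
        // field assignment; here we reuse the unmodified chunk as a placeholder.
        let replacement = chunks.chunks[0].circ.clone();
        chunks.chunks[0].circ = replacement;

        // Stitch the chunks back together and check the result is still valid
        // and hash-equivalent to the original circuit.
        let mut reassembled = chunks.reassemble().unwrap();
        reassembled.infer_and_validate(&REGISTRY).unwrap();
        assert_eq!(circ.circuit_hash(), reassembled.circuit_hash());
    }

The Python bindings added in the same patch expose the equivalent flow as `circuit.chunks(c, max_chunk_size)`, `chunks.circuits()`, `chunks.update_circuit(index, circ)` and `chunks.reassemble()`, as exercised by `test_chunks` above.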