
Commit

Merge branch 'main' into feat/circuit-mut
lmondada authored Sep 27, 2023
2 parents 13213d5 + e33fc6a commit 7c153d7
Showing 5 changed files with 387 additions and 130 deletions.
123 changes: 51 additions & 72 deletions src/optimiser/taso.rs
@@ -12,6 +12,7 @@
//! it gets too large.
mod eq_circ_class;
mod hugr_hash_set;
mod hugr_pchannel;
mod hugr_pqueue;
pub mod log;
@@ -24,11 +25,11 @@ pub use eq_circ_class::{load_eccs_json_file, EqCircClass}
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

use fxhash::FxHashSet;
use hugr::Hugr;

use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel;
use crate::optimiser::taso::hugr_hash_set::HugrHashSet;
use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog};
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::rewrite::strategy::RewriteStrategy;
@@ -111,7 +112,7 @@ where
logger.log_best(best_circ_cost);

// Hash of seen circuits. Do not store circuits, as this set gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]);
let mut seen_hashes = HugrHashSet::singleton(circ.circuit_hash(), best_circ_cost);

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
@@ -129,19 +130,26 @@

let rewrites = self.rewriter.get_rewrites(&circ);
for new_circ in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = (self.cost)(&new_circ);
if pq.len() > PRIORITY_QUEUE_CAPACITY / 2 && new_circ_cost > *pq.max_cost().unwrap()
{
// Ignore this circuit: it's too big
continue;
}
let new_circ_hash = new_circ.circuit_hash();
circ_cnt += 1;
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
if seen_hashes.contains(&new_circ_hash) {
if !seen_hashes.insert(new_circ_hash, new_circ_cost) {
// Ignore this circuit: we've already seen it
continue;
}
pq.push_with_hash_unchecked(new_circ, new_circ_hash);
seen_hashes.insert(new_circ_hash);
circ_cnt += 1;
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);
}

if pq.len() >= PRIORITY_QUEUE_CAPACITY {
// Haircut to keep the queue size manageable
pq.truncate(PRIORITY_QUEUE_CAPACITY / 2);
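// Also drop the seen hashes of circuits costlier than anything left in the queue, keeping the set small.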
seen_hashes.clear_over(*pq.max_cost().unwrap());
}

if let Some(timeout) = timeout {
@@ -172,51 +180,37 @@ where
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

// multi-consumer priority channel for queuing circuits to be processed by the workers
let (tx_work, rx_work) =
let mut pq =
HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads);
// channel for sending circuits from threads back to main
let (tx_result, rx_result) = crossbeam_channel::unbounded();

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(&best_circ);
logger.log_best(best_circ_cost);

// Hash of seen circuits. Do not store circuits, as this set gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]);

// Each worker waits for circuits to scan for rewrites using all the
// patterns and sends the results back to main.
let joins: Vec<_> = (0..n_threads)
.map(|i| {
TasoWorker::spawn(
rx_work.clone(),
tx_result.clone(),
pq.pop.clone().unwrap(),
pq.push.clone().unwrap(),
self.rewriter.clone(),
self.strategy.clone(),
Some(format!("taso-worker-{i}")),
)
})
.collect();
// Drop our copy of the worker channels, so we don't count as a
// connected worker.
drop(rx_work);
drop(tx_result);

// Queue the initial circuit
tx_work
pq.push
.as_ref()
.unwrap()
.send(vec![(initial_circ_hash, circ.clone())])
.unwrap();
// Drop our copy of the priority queue channels, so we don't count as a
// connected worker.
pq.drop_pop_push();

// A counter of circuits seen.
let mut circ_cnt = 1;

// A counter of jobs sent to the workers.
#[allow(unused)]
let mut jobs_sent = 0usize;
// A counter of completed jobs received from the workers.
#[allow(unused)]
let mut jobs_completed = 0usize;
// TODO: Report dropped jobs in the queue, so we can check for termination.

// Deadline for the optimisation timeout
@@ -225,66 +219,51 @@ where
Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)),
};

// Process worker results until we have seen all the circuits, or we run
// out of time.
// Main loop: log best circuits as they come in from the priority queue,
// until the timeout is reached.
let mut timeout_flag = false;
loop {
select! {
recv(rx_result) -> msg => {
recv(pq.log) -> msg => {
match msg {
Ok(hashed_circs) => {
let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| {
jobs_completed += 1;
for (circ_hash, circ) in &hashed_circs {
circ_cnt += 1;
logger.log_progress(circ_cnt, None, seen_hashes.len());
if seen_hashes.contains(circ_hash) {
continue;
}
seen_hashes.insert(*circ_hash);

let cost = (self.cost)(circ);

// Check if we got a new best circuit
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
}
jobs_sent += 1;
}
// Fill the workqueue with data from pq
tx_work.send(hashed_circs)
});
if send_result.is_err() {
eprintln!("All our workers panicked. Stopping optimisation.");
break;
}

// If there is no more data to process, we are done.
//
// TODO: Report dropped jobs in the workers, so we can check for termination.
//if jobs_sent == jobs_completed {
// break 'main;
//};
Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(best_circ_cost);
},
Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => {
logger.log_progress(circuit_cnt, None, seen_cnt);
}
Err(crossbeam_channel::RecvError) => {
eprintln!("All our workers panicked. Stopping optimisation.");
eprintln!("Priority queue panicked. Stopping optimisation.");
break;
}
}
}
recv(timeout_event) -> _ => {
timeout_flag = true;
pq.timeout();
break;
}
}
}

logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag);
// Empty the log from the priority queue and store final circuit count.
let mut circuit_cnt = None;
while let Ok(log) = pq.log.recv() {
match log {
PriorityChannelLog::NewBestCircuit(circ, cost) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(best_circ_cost);
}
PriorityChannelLog::CircuitCount(circ_cnt, _) => {
circuit_cnt = Some(circ_cnt);
}
}
}
logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag);

// Drop the channel so the threads know to stop.
drop(tx_work);
joins.into_iter().for_each(|j| j.join().unwrap());

best_circ
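The multi-threaded path now funnels all coordination through a single priority channel: workers pull candidates from `pq.pop` and push results to `pq.push`, while the main thread only listens on `pq.log` for new best circuits and progress counts, racing that against a timeout. Below is a minimal sketch of that select-loop shape built on plain `crossbeam_channel` primitives; the `Log` enum and the function names are illustrative stand-ins, not the actual `HugrPriorityChannel` API.

```rust
use std::thread;
use std::time::{Duration, Instant};

use crossbeam_channel::{select, unbounded, Receiver, Sender};

// Hypothetical stand-in for the messages carried on `pq.log`.
enum Log {
    NewBest(usize),      // cost of a new best candidate
    Count(usize, usize), // (circuits processed, circuits seen)
}

// Drain log messages until the deadline fires or every sender hangs up.
fn main_loop(log: Receiver<Log>, timeout_secs: u64) -> usize {
    let deadline = crossbeam_channel::at(Instant::now() + Duration::from_secs(timeout_secs));
    let mut best_cost = usize::MAX;
    loop {
        select! {
            recv(log) -> msg => match msg {
                Ok(Log::NewBest(cost)) if cost < best_cost => best_cost = cost,
                Ok(_) => {}
                // All senders dropped: the priority-channel side has shut down.
                Err(_) => break,
            },
            // Timeout reached.
            recv(deadline) -> _ => break,
        }
    }
    best_cost
}

fn main() {
    let (tx, rx): (Sender<Log>, Receiver<Log>) = unbounded();
    let worker = thread::spawn(move || {
        for cost in [10, 7, 9, 5] {
            tx.send(Log::NewBest(cost)).unwrap();
        }
        tx.send(Log::Count(4, 4)).unwrap();
        // `tx` is dropped here, which ends the main loop early.
    });
    let best = main_loop(rx, 5);
    worker.join().unwrap();
    assert_eq!(best, 5);
}
```

Routing all bookkeeping through one logging channel keeps the main thread agnostic about the number of workers: it simply stops when the channel disconnects or the deadline fires, which mirrors how the loop above terminates.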
155 changes: 155 additions & 0 deletions src/optimiser/taso/hugr_hash_set.rs
@@ -0,0 +1,155 @@
use std::collections::VecDeque;

use fxhash::FxHashSet;

/// A data structure storing Hugr hashes.
///
/// Stores hashes in buckets based on a cost function, to allow clearing
/// the set of hashes that are no longer needed.
pub(super) struct HugrHashSet {
buckets: VecDeque<FxHashSet<u64>>,
/// The cost at the front of the queue.
min_cost: Option<usize>,
}

impl HugrHashSet {
/// Create a new empty set.
pub(super) fn new() -> Self {
Self {
buckets: VecDeque::new(),
min_cost: None,
}
}

/// Create a new set with a single hash and cost.
pub(super) fn singleton(hash: u64, cost: usize) -> Self {
let mut set = Self::new();
set.insert(hash, cost);
set
}

/// Insert circuit with given hash and cost.
///
/// Returns `true` if the hash was newly inserted, i.e. it was not
/// already present.
pub(super) fn insert(&mut self, hash: u64, cost: usize) -> bool {
let Some(min_cost) = self.min_cost.as_mut() else {
self.min_cost = Some(cost);
self.buckets.push_front([hash].into_iter().collect());
return true;
};
self.buckets.reserve(min_cost.saturating_sub(cost));
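// If the new cost is below the current minimum, grow the deque at the front so it becomes bucket 0.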
while cost < *min_cost {
self.buckets.push_front(FxHashSet::default());
*min_cost -= 1;
}
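// Bucket `i` holds the hashes of circuits with cost `min_cost + i`.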
let bucket_index = cost - *min_cost;
if bucket_index >= self.buckets.len() {
self.buckets
.resize_with(bucket_index + 1, FxHashSet::default);
}
self.buckets[bucket_index].insert(hash)
}

/// Returns whether the given hash is present in the set.
#[allow(dead_code)]
pub(super) fn contains(&self, hash: u64, cost: usize) -> bool {
let Some(min_cost) = self.min_cost else {
return false;
};
let Some(index) = cost.checked_sub(min_cost) else {
return false;
};
let Some(b) = self.buckets.get(index) else {
return false;
};
b.contains(&hash)
}

fn max_cost(&self) -> Option<usize> {
Some(self.min_cost? + self.buckets.len() - 1)
}

/// Remove all hashes with cost strictly greater than the given cost.
pub(super) fn clear_over(&mut self, cost: usize) {
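// Pop the highest-cost bucket until the maximum cost is at most `cost`, or the set is empty.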
while self.max_cost().is_some() && self.max_cost() > Some(cost) {
self.buckets.pop_back();
if self.buckets.is_empty() {
self.min_cost = None;
}
}
}

/// The number of hashes in the set
pub(super) fn len(&self) -> usize {
self.buckets.iter().map(|b| b.len()).sum()
}
}

#[cfg(test)]
mod tests {
use super::HugrHashSet;

#[test]
fn insert_elements() {
// For simplicity, we use `hash % 10` as the cost
let mut set = HugrHashSet::new();

assert!(!set.contains(0, 0));
assert!(!set.contains(2, 0));
assert!(!set.contains(2, 3));

assert!(set.insert(20, 2));
assert!(!set.contains(0, 0));
assert!(!set.insert(20, 2));
assert!(set.insert(22, 2));
assert!(set.insert(23, 2));

assert!(set.contains(22, 2));

assert!(set.insert(33, 3));
assert_eq!(set.min_cost, Some(2));
assert_eq!(set.max_cost(), Some(3));
assert_eq!(
set.buckets,
[
[20, 22, 23].into_iter().collect(),
[33].into_iter().collect()
]
);

assert!(set.insert(3, 0));
assert!(set.insert(1, 0));
assert!(!set.insert(22, 2));
assert!(set.contains(33, 3));
assert!(set.contains(3, 0));
assert!(!set.contains(3, 2));
assert_eq!(set.min_cost, Some(0));
assert_eq!(set.max_cost(), Some(3));

assert_eq!(set.min_cost, Some(0));
assert_eq!(
set.buckets,
[
[1, 3].into_iter().collect(),
[].into_iter().collect(),
[20, 22, 23].into_iter().collect(),
[33].into_iter().collect(),
]
);
}

#[test]
fn remove_empty() {
let mut set = HugrHashSet::new();
assert!(set.insert(20, 2));
assert!(set.insert(30, 3));

assert_eq!(set.len(), 2);
set.clear_over(2);
assert_eq!(set.len(), 1);
set.clear_over(0);
assert_eq!(set.len(), 0);
assert!(set.min_cost.is_none());
}
}
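Tying this back to the single-threaded loop in `taso.rs`: the set above is one of three cooperating pieces, namely a bounded priority queue of candidates, the seen-hash set, and a periodic haircut that truncates the queue and then calls `clear_over` to discard hashes that can no longer be in it. The following is a rough, self-contained sketch of that control flow, with a `BinaryHeap` and a plain `HashMap` standing in for `HugrPQ` and `HugrHashSet`, and `u64` "circuits" whose cost is `circ % 10` as in the tests above (all names and numbers are illustrative).

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};

const QUEUE_CAPACITY: usize = 8;

// Toy stand-ins: a "circuit" is a u64 and its cost is `circ % 10`.
fn cost(circ: u64) -> usize {
    (circ % 10) as usize
}
fn hash(circ: u64) -> u64 {
    circ // identity hash, for illustration only
}

fn main() {
    // Min-queue of (cost, circuit) candidates, kept below QUEUE_CAPACITY.
    let mut pq: BinaryHeap<Reverse<(usize, u64)>> = BinaryHeap::new();
    // Hashes of everything already queued, with their cost. The real code
    // uses a cost-bucketed set so that pruning by cost is cheap.
    let mut seen: HashMap<u64, usize> = HashMap::new();

    let candidates = [23u64, 45, 23, 17, 99, 45, 12, 31, 8, 56, 74, 60];
    for circ in candidates {
        let c = cost(circ);
        // Skip circuits we have queued before.
        if seen.insert(hash(circ), c).is_some() {
            continue;
        }
        pq.push(Reverse((c, circ)));

        if pq.len() >= QUEUE_CAPACITY {
            // Haircut: keep only the cheaper half of the queue...
            let keep: Vec<_> = (0..QUEUE_CAPACITY / 2).filter_map(|_| pq.pop()).collect();
            let max_kept = keep.last().map(|Reverse((c, _))| *c).unwrap_or(0);
            pq = keep.into_iter().collect();
            // ...and forget the hashes of anything costlier than what remains.
            seen.retain(|_, &mut c| c <= max_kept);
        }
    }
    println!("{} queued, {} hashes retained", pq.len(), seen.len());
}
```

The real `HugrHashSet` replaces the linear `retain` with cost buckets, so pruning everything above a threshold amounts to popping a few buckets off the back of a `VecDeque`.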
