Skip to content

Commit

Permalink
revert: Reduce TASO hashtable size (#144)
Browse files Browse the repository at this point in the history
This reverts commit e33fc6a.
  • Loading branch information
lmondada authored Sep 28, 2023
1 parent 7c25d83 commit 387cae0
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 387 deletions.
123 changes: 72 additions & 51 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
//! it gets too large.
mod eq_circ_class;
mod hugr_hash_set;
mod hugr_pchannel;
mod hugr_pqueue;
pub mod log;
Expand All @@ -25,11 +24,11 @@ pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

use fxhash::FxHashSet;
use hugr::Hugr;

use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_hash_set::HugrHashSet;
use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog};
use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel;
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::rewrite::strategy::RewriteStrategy;
Expand Down Expand Up @@ -112,7 +111,7 @@ where
logger.log_best(best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let mut seen_hashes = HugrHashSet::singleton(circ.circuit_hash(), best_circ_cost);
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]);

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
Expand All @@ -130,26 +129,19 @@ where

let rewrites = self.rewriter.get_rewrites(&circ);
for new_circ in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = (self.cost)(&new_circ);
if pq.len() > PRIORITY_QUEUE_CAPACITY / 2 && new_circ_cost > *pq.max_cost().unwrap()
{
// Ignore this circuit: it's too big
continue;
}
let new_circ_hash = new_circ.circuit_hash();
circ_cnt += 1;
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
if !seen_hashes.insert(new_circ_hash, new_circ_cost) {
// Ignore this circuit: we've already seen it
if seen_hashes.contains(&new_circ_hash) {
continue;
}
circ_cnt += 1;
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);
pq.push_with_hash_unchecked(new_circ, new_circ_hash);
seen_hashes.insert(new_circ_hash);
}

if pq.len() >= PRIORITY_QUEUE_CAPACITY {
// Haircut to keep the queue size manageable
pq.truncate(PRIORITY_QUEUE_CAPACITY / 2);
seen_hashes.clear_over(*pq.max_cost().unwrap());
}

if let Some(timeout) = timeout {
Expand Down Expand Up @@ -180,37 +172,51 @@ where
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

// multi-consumer priority channel for queuing circuits to be processed by the workers
let mut pq =
let (tx_work, rx_work) =
HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads);
// channel for sending circuits from threads back to main
let (tx_result, rx_result) = crossbeam_channel::unbounded();

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(&best_circ);
logger.log_best(best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]);

// Each worker waits for circuits to scan for rewrites using all the
// patterns and sends the results back to main.
let joins: Vec<_> = (0..n_threads)
.map(|i| {
TasoWorker::spawn(
pq.pop.clone().unwrap(),
pq.push.clone().unwrap(),
rx_work.clone(),
tx_result.clone(),
self.rewriter.clone(),
self.strategy.clone(),
Some(format!("taso-worker-{i}")),
)
})
.collect();
// Drop our copy of the worker channels, so we don't count as a
// connected worker.
drop(rx_work);
drop(tx_result);

// Queue the initial circuit
pq.push
.as_ref()
.unwrap()
tx_work
.send(vec![(initial_circ_hash, circ.clone())])
.unwrap();
// Drop our copy of the priority queue channels, so we don't count as a
// connected worker.
pq.drop_pop_push();

// A counter of circuits seen.
let mut circ_cnt = 1;

// A counter of jobs sent to the workers.
#[allow(unused)]
let mut jobs_sent = 0usize;
// A counter of completed jobs received from the workers.
#[allow(unused)]
let mut jobs_completed = 0usize;
// TODO: Report dropped jobs in the queue, so we can check for termination.

// Deadline for the optimisation timeout
Expand All @@ -219,51 +225,66 @@ where
Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)),
};

// Main loop: log best circuits as they come in from the priority queue,
// until the timeout is reached.
// Process worker results until we have seen all the circuits, or we run
// out of time.
let mut timeout_flag = false;
loop {
select! {
recv(pq.log) -> msg => {
recv(rx_result) -> msg => {
match msg {
Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(best_circ_cost);
Ok(hashed_circs) => {
let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| {
jobs_completed += 1;
for (circ_hash, circ) in &hashed_circs {
circ_cnt += 1;
logger.log_progress(circ_cnt, None, seen_hashes.len());
if seen_hashes.contains(circ_hash) {
continue;
}
seen_hashes.insert(*circ_hash);

let cost = (self.cost)(circ);

// Check if we got a new best circuit
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
}
jobs_sent += 1;
}
// Fill the workqueue with data from pq
tx_work.send(hashed_circs)
});
if send_result.is_err() {
eprintln!("All our workers panicked. Stopping optimisation.");
break;
}

// If there is no more data to process, we are done.
//
// TODO: Report dropped jobs in the workers, so we can check for termination.
//if jobs_sent == jobs_completed {
// break 'main;
//};
},
Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => {
logger.log_progress(circuit_cnt, None, seen_cnt);
}
Err(crossbeam_channel::RecvError) => {
eprintln!("Priority queue panicked. Stopping optimisation.");
eprintln!("All our workers panicked. Stopping optimisation.");
break;
}
}
}
recv(timeout_event) -> _ => {
timeout_flag = true;
pq.timeout();
break;
}
}
}

// Empty the log from the priority queue and store final circuit count.
let mut circuit_cnt = None;
while let Ok(log) = pq.log.recv() {
match log {
PriorityChannelLog::NewBestCircuit(circ, cost) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(best_circ_cost);
}
PriorityChannelLog::CircuitCount(circ_cnt, _) => {
circuit_cnt = Some(circ_cnt);
}
}
}
logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag);
logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag);

// Drop the channel so the threads know to stop.
drop(tx_work);
joins.into_iter().for_each(|j| j.join().unwrap());

best_circ
Expand Down
155 changes: 0 additions & 155 deletions src/optimiser/taso/hugr_hash_set.rs

This file was deleted.

Loading

0 comments on commit 387cae0

Please sign in to comment.