
Commit

Merge branch 'main' into feat/json-angleadd
lmondada authored Oct 4, 2023
2 parents c53cee3 + 7099098 commit d7e1d74
Showing 7 changed files with 261 additions and 150 deletions.
9 changes: 6 additions & 3 deletions src/json.rs
@@ -139,21 +139,24 @@ pub fn load_tk1_json_str(json: &str) -> Result<Hugr, TK1ConvertError> {
}

/// Save a circuit to file in TK1 JSON format.
pub fn save_tk1_json_file(path: impl AsRef<Path>, circ: &Hugr) -> Result<(), TK1ConvertError> {
pub fn save_tk1_json_file(
circ: &impl Circuit,
path: impl AsRef<Path>,
) -> Result<(), TK1ConvertError> {
let file = fs::File::create(path)?;
let writer = io::BufWriter::new(file);
save_tk1_json_writer(circ, writer)
}

/// Save a circuit in TK1 JSON format to a writer.
pub fn save_tk1_json_writer(circ: &Hugr, w: impl io::Write) -> Result<(), TK1ConvertError> {
pub fn save_tk1_json_writer(circ: &impl Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> {
let serial_circ = SerialCircuit::encode(circ)?;
serde_json::to_writer(w, &serial_circ)?;
Ok(())
}

/// Save a circuit in TK1 JSON format to a String.
pub fn save_tk1_json_str(circ: &Hugr) -> Result<String, TK1ConvertError> {
pub fn save_tk1_json_str(circ: &impl Circuit) -> Result<String, TK1ConvertError> {
let mut buf = io::BufWriter::new(Vec::new());
save_tk1_json_writer(circ, &mut buf)?;
let bytes = buf.into_inner().unwrap();
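With the new signatures the circuit argument comes first and any type implementing Circuit is accepted. A hypothetical call site, sketched under the assumption that a Hugr value named circ implements Circuit (the helper name and output path are illustrative, not part of this commit):

// Hypothetical helper, not part of this commit.
fn export_circuit(circ: &Hugr) -> Result<(), TK1ConvertError> {
    // Circuit first, destination second, matching the new argument order.
    let _json: String = save_tk1_json_str(circ)?;
    save_tk1_json_file(circ, "circuit.json")?;
    Ok(())
}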
135 changes: 58 additions & 77 deletions src/optimiser/taso.rs
@@ -20,17 +20,17 @@ mod worker;

use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use fxhash::FxHashSet;
pub use log::TasoLogger;

use std::fmt;
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

use fxhash::FxHashSet;
use hugr::Hugr;

use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel;
use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog};
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::rewrite::strategy::RewriteStrategy;
@@ -81,7 +81,10 @@ where
/// Run the TASO optimiser on a circuit.
///
/// A timeout (in seconds) can be provided.
pub fn optimise(&self, circ: &Hugr, timeout: Option<u64>, n_threads: NonZeroUsize) -> Hugr {
pub fn optimise(&self, circ: &Hugr, timeout: Option<u64>, n_threads: NonZeroUsize) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
self.optimise_with_log(circ, Default::default(), timeout, n_threads)
}

@@ -94,7 +97,10 @@
log_config: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Hugr {
) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
match n_threads.get() {
1 => self.taso(circ, log_config, timeout),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads),
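A hypothetical call site for the entry points above, assuming an already-constructed TasoOptimiser value named optimiser and a Hugr circuit circ (both names illustrative):

use std::num::NonZeroUsize;

// Run on 4 threads with a 30-second timeout; `None` would disable the timeout.
let n_threads = NonZeroUsize::new(4).unwrap();
let best = optimiser.optimise(&circ, Some(30), n_threads);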
@@ -110,7 +116,8 @@ where
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Do not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]);
let mut seen_hashes = FxHashSet::default();
seen_hashes.insert(circ.circuit_hash());

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
@@ -133,13 +140,14 @@ where
let rewrites = self.rewriter.get_rewrites(&circ);
for new_circ in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_hash = new_circ.circuit_hash();
circ_cnt += 1;
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
if seen_hashes.contains(&new_circ_hash) {
if !seen_hashes.insert(new_circ_hash) {
// Ignore this circuit: we've already seen it
continue;
}
pq.push_with_hash_unchecked(new_circ, new_circ_hash);
seen_hashes.insert(new_circ_hash);
circ_cnt += 1;
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
let new_circ_cost = self.cost(&new_circ);
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);
}

if pq.len() >= PRIORITY_QUEUE_CAPACITY {
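The change above folds the separate contains check and insert into a single call. The pattern in isolation, as a minimal sketch with toy hash values:

use fxhash::FxHashSet;

fn dedup_demo() {
    let mut seen_hashes = FxHashSet::default();
    for hash in [42u64, 7, 42] {
        if !seen_hashes.insert(hash) {
            // `insert` returns false when the value was already present.
            continue;
        }
        // Only reached the first time a hash is seen (here: 42, then 7).
    }
}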
@@ -170,7 +178,10 @@ where
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Hugr {
) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
let n_threads: usize = n_threads.get();
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

@@ -179,51 +190,36 @@ where
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let (tx_work, rx_work) =
HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads);
// channel for sending circuits from threads back to main
let (tx_result, rx_result) = crossbeam_channel::unbounded();
let mut pq = HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads);

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
let mut best_circ_cost = self.cost(&best_circ);
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Do not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]);

// Each worker waits for circuits to scan for rewrites using all the
// patterns and sends the results back to main.
let joins: Vec<_> = (0..n_threads)
.map(|i| {
TasoWorker::spawn(
rx_work.clone(),
tx_result.clone(),
pq.pop.clone().unwrap(),
pq.push.clone().unwrap(),
self.rewriter.clone(),
self.strategy.clone(),
Some(format!("taso-worker-{i}")),
)
})
.collect();
// Drop our copy of the worker channels, so we don't count as a
// connected worker.
drop(rx_work);
drop(tx_result);

// Queue the initial circuit
tx_work
pq.push
.as_ref()
.unwrap()
.send(vec![(initial_circ_hash, circ.clone())])
.unwrap();
// Drop our copy of the priority queue channels, so we don't count as a
// connected worker.
pq.drop_pop_push();

// A counter of circuits seen.
let mut circ_cnt = 1;

// A counter of jobs sent to the workers.
#[allow(unused)]
let mut jobs_sent = 0usize;
// A counter of completed jobs received from the workers.
#[allow(unused)]
let mut jobs_completed = 0usize;
// TODO: Report dropped jobs in the queue, so we can check for termination.

// Deadline for the optimisation timeout
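The clone-then-drop pattern above (hand each worker its own channel endpoints, then drop the main thread's copies so that channel disconnection doubles as the shutdown signal), reduced to a standalone sketch with plain crossbeam channels, independent of the real TasoWorker and HugrPriorityChannel types:

use crossbeam_channel::unbounded;
use std::thread;

fn worker_shutdown_demo() {
    let (tx, rx) = unbounded::<u64>();
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let rx = rx.clone();
            thread::spawn(move || {
                // `recv` returns Err once every sender has been dropped.
                while let Ok(job) = rx.recv() {
                    println!("worker {i} processed job {job}");
                }
            })
        })
        .collect();
    // Drop the main thread's receiver so only the workers count as consumers.
    drop(rx);
    for job in 0..8 {
        tx.send(job).unwrap();
    }
    // Dropping the last sender disconnects the channel; the workers then exit.
    drop(tx);
    handles.into_iter().for_each(|h| h.join().unwrap());
}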
@@ -232,66 +228,51 @@ where
Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)),
};

// Process worker results until we have seen all the circuits, or we run
// out of time.
// Main loop: log best circuits as they come in from the priority queue,
// until the timeout is reached.
let mut timeout_flag = false;
loop {
select! {
recv(rx_result) -> msg => {
recv(pq.log) -> msg => {
match msg {
Ok(hashed_circs) => {
let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| {
jobs_completed += 1;
for (circ_hash, circ) in &hashed_circs {
circ_cnt += 1;
logger.log_progress(circ_cnt, None, seen_hashes.len());
if seen_hashes.contains(circ_hash) {
continue;
}
seen_hashes.insert(*circ_hash);

let cost = self.cost(circ);

// Check if we got a new best circuit
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
}
jobs_sent += 1;
}
// Fill the workqueue with data from pq
tx_work.send(hashed_circs)
});
if send_result.is_err() {
eprintln!("All our workers panicked. Stopping optimisation.");
break;
}

// If there is no more data to process, we are done.
//
// TODO: Report dropped jobs in the workers, so we can check for termination.
//if jobs_sent == jobs_completed {
// break 'main;
//};
Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
},
Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => {
logger.log_progress(circuit_cnt, None, seen_cnt);
}
Err(crossbeam_channel::RecvError) => {
eprintln!("All our workers panicked. Stopping optimisation.");
eprintln!("Priority queue panicked. Stopping optimisation.");
break;
}
}
}
recv(timeout_event) -> _ => {
timeout_flag = true;
pq.timeout();
break;
}
}
}

logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag);
// Empty the log from the priority queue and store final circuit count.
let mut circuit_cnt = None;
while let Ok(log) = pq.log.recv() {
match log {
PriorityChannelLog::NewBestCircuit(circ, cost) => {
best_circ = circ;
best_circ_cost = cost;
logger.log_best(&best_circ_cost);
}
PriorityChannelLog::CircuitCount(circ_cnt, _) => {
circuit_cnt = Some(circ_cnt);
}
}
}
logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag);

// Drop the channel so the threads know to stop.
drop(tx_work);
joins.into_iter().for_each(|j| j.join().unwrap());

best_circ
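The main loop's select-with-deadline shape, reduced to a standalone sketch (this uses crossbeam_channel::after where the code above uses at; the message type and values are placeholders):

use crossbeam_channel::{after, select, unbounded};
use std::time::Duration;

fn select_with_deadline_demo() {
    let (tx, rx) = unbounded::<&str>();
    let deadline = after(Duration::from_secs(5));
    tx.send("new best circuit").unwrap();
    drop(tx);
    loop {
        select! {
            recv(rx) -> msg => {
                match msg {
                    Ok(m) => println!("log: {m}"),
                    // The sending side is gone: nothing more will arrive.
                    Err(_) => break,
                }
            }
            recv(deadline) -> _ => {
                // Deadline reached: stop and fall through to cleanup.
                break;
            }
        }
    }
}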
