Skip to content

Commit

Permalink
feat: queue size argument for TASO
Browse files — browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Oct 5, 2023
1 parent f42746f commit 4e5ac31
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 33 deletions.
2 changes: 2 additions & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ impl PyDefaultTasoOptimiser {
n_threads: Option<NonZeroUsize>,
split_circ: Option<bool>,
log_progress: Option<PathBuf>,
queue_size: Option<usize>,
) -> PyResult<PyObject> {
let taso_logger = log_progress
.map(|file_name| {
Expand All @@ -78,6 +79,7 @@ impl PyDefaultTasoOptimiser {
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
split_circ.unwrap_or(false),
queue_size.unwrap_or(10_000),
)
})
}
Expand Down
48 changes: 34 additions & 14 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,16 @@ where
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
queue_size: usize,
) -> Hugr {
self.optimise_with_log(circ, Default::default(), timeout, n_threads, split_circuit)
self.optimise_with_log(
circ,
Default::default(),
timeout,
n_threads,
split_circuit,
queue_size,
)
}

/// Run the TASO optimiser on a circuit with logging activated.
Expand All @@ -105,20 +113,27 @@ where
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
queue_size: usize,
) -> Hugr {
if split_circuit && n_threads.get() > 1 {
return self
.split_run(circ, log_config, timeout, n_threads)
.split_run(circ, log_config, timeout, n_threads, queue_size)
.unwrap();
}
match n_threads.get() {
1 => self.taso(circ, log_config, timeout),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads),
1 => self.taso(circ, log_config, timeout, queue_size),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads, queue_size),
}
}

#[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))]
fn taso(&self, circ: &Hugr, mut logger: TasoLogger, timeout: Option<u64>) -> Hugr {
fn taso(
&self,
circ: &Hugr,
mut logger: TasoLogger,
timeout: Option<u64>,
queue_size: usize,
) -> Hugr {
let start_time = Instant::now();

let mut best_circ = circ.clone();
Expand All @@ -130,12 +145,11 @@ where
seen_hashes.insert(circ.circuit_hash());

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let mut pq = HugrPQ::with_capacity(cost_fn, PRIORITY_QUEUE_CAPACITY);
let mut pq = HugrPQ::with_capacity(cost_fn, queue_size);
pq.push(circ.clone());

let mut circ_cnt = 1;
Expand All @@ -160,9 +174,9 @@ where
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);
}

if pq.len() >= PRIORITY_QUEUE_CAPACITY {
if pq.len() >= queue_size {
// Haircut to keep the queue size manageable
pq.truncate(PRIORITY_QUEUE_CAPACITY / 2);
pq.truncate(queue_size / 2);
}

if let Some(timeout) = timeout {
Expand All @@ -188,16 +202,16 @@ where
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
queue_size: usize,
) -> Hugr {
let n_threads: usize = n_threads.get();
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

// multi-consumer priority channel for queuing circuits to be processed by the workers
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let mut pq = HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads);
let mut pq = HugrPriorityChannel::init(cost_fn, queue_size);

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
Expand Down Expand Up @@ -293,6 +307,7 @@ where
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
queue_size: usize,
) -> Result<Hugr, HugrError> {
let circ_cost = self.cost(circ);
let max_chunk_cost = circ_cost.clone().div_cost(n_threads);
Expand All @@ -317,8 +332,13 @@ where
let join = thread::Builder::new()
.name(format!("chunk-{}", i))
.spawn(move || {
let res =
taso.optimise(&chunk, timeout, NonZeroUsize::new(1).unwrap(), false);
let res = taso.optimise(
&chunk,
timeout,
NonZeroUsize::new(1).unwrap(),
false,
queue_size,
);
tx.send(res).unwrap();
})
.unwrap();
Expand Down Expand Up @@ -428,7 +448,7 @@ mod tests {

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) {
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false);
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false, 10_000);
let cmds = opt_rz
.commands()
.map(|cmd| {
Expand Down
10 changes: 5 additions & 5 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ impl<'w> TasoLogger<'w> {
needs_joining: bool,
timeout: bool,
) {
if timeout {
self.log("Timeout");
}
self.log("Optimisation finished");
match timeout {
true => self.log("Optimisation finished (timeout)"),
false => self.log("Optimisation finished"),
};
self.log(format!("Tried {circuit_count} circuits"));
self.log(format!("END RESULT: {:?}", best_cost));
self.log(format!("---- END RESULT: {:?} ----", best_cost));
if needs_joining {
self.log("Joining worker threads");
}
Expand Down
43 changes: 29 additions & 14 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,41 @@ struct CmdLineArgs {
help = "The number of threads to use. By default, use a single thread."
)]
n_threads: Option<NonZeroUsize>,
/// Number of threads (default=1)
/// Split the circuit into chunks, and process them separately.
#[arg(
long = "split-circ",
help = "Split the circuit into chunks and optimize each one in a separate thread. Use `-j` to specify the number of threads to use."
)]
split_circ: bool,
/// Max queue size.
#[arg(
short = 'q',
long = "queue-size",
default_value = "10000",
value_name = "QUEUE_SIZE",
help = "The priority queue size. Defaults to 10_000."
)]
queue_size: usize,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
let opts = CmdLineArgs::parse();

// Setup tracing subscribers for stdout and file logging.
//
// We need to keep the object around to keep the logging active.
let _tracer = Tracer::setup_tracing(opts.logfile, opts.split_circ);

let input_path = Path::new(&opts.input);
let output_path = Path::new(&opts.output);
let ecc_path = Path::new(&opts.eccs);

let n_threads = opts
.n_threads
// TODO: Default to multithreading once that produces better results.
//.or_else(|| std::thread::available_parallelism().ok())
.unwrap_or(NonZeroUsize::new(1).unwrap());

// Setup tracing subscribers for stdout and file logging.
//
// We need to keep the object around to keep the logging active.
let _tracer = Tracer::setup_tracing(opts.logfile, n_threads.get() > 1);

// TODO: Remove this from the Logger, and use tracing events instead.
let circ_candidates_csv = BufWriter::new(File::create("best_circs.csv")?);

Expand All @@ -113,21 +128,21 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
);
exit(1);
};

let n_threads = opts
.n_threads
// TODO: Default to multithreading once that produces better results.
//.or_else(|| std::thread::available_parallelism().ok())
.unwrap_or(NonZeroUsize::new(1).unwrap());
println!("Using {n_threads} threads");

if opts.split_circ && n_threads.get() > 1 {
println!("Splitting circuit into {n_threads} chunks.");
}

println!("Optimising...");
let opt_circ =
optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads, opts.split_circ);
let opt_circ = optimiser.optimise_with_log(
&circ,
taso_logger,
opts.timeout,
n_threads,
opts.split_circ,
opts.queue_size,
);

println!("Saving result");
save_tk1_json_file(&opt_circ, output_path)?;
Expand Down

0 comments on commit 4e5ac31

Please sign in to comment.