diff --git a/Cargo.toml b/Cargo.toml index f582ded7..6161abcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ chrono = { version ="0.4.30" } bytemuck = "1.14.0" stringreader = "0.1.1" crossbeam-channel = "0.5.8" +tracing = { workspace = true } [features] pyo3 = [ @@ -79,3 +80,4 @@ itertools = { version = "0.11.0" } tket-json-rs = { git = "https://github.com/CQCL/tket-json-rs", rev = "619db15d3", features = [ "tket2ops", ] } +tracing = "0.1.37" diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 6a48d942..66c91ab6 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -14,54 +14,28 @@ mod eq_circ_class; mod hugr_pchannel; mod hugr_pqueue; +pub mod log; mod qtz_circuit; mod worker; use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; +use std::io; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; -use std::{fs, io}; use fxhash::FxHashSet; use hugr::Hugr; use crate::circuit::CircuitHash; -use crate::json::save_tk1_json_writer; +use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; +use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ}; +use crate::optimiser::taso::worker::TasoWorker; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -use hugr_pqueue::{Entry, HugrPQ}; -use self::hugr_pchannel::HugrPriorityChannel; - -/// Logging configuration for the TASO optimiser. -#[derive(Default)] -pub struct LogConfig<'w> { - final_circ_json: Option>, - circ_candidates_csv: Option>, - progress_log: Option>, -} - -impl<'w> LogConfig<'w> { - /// Create a new logging configuration. - /// - /// Three writer objects must be provided: - /// - best_circ_json: for the final optimised circuit, in TK1 JSON format, - /// - circ_candidates_csv: for a log of the successive best candidate circuits, - /// - progress_log: for a log of the progress of the optimisation. - pub fn new( - best_circ_json: impl io::Write + 'w, - circ_candidates_csv: impl io::Write + 'w, - progress_log: impl io::Write + 'w, - ) -> Self { - Self { - final_circ_json: Some(Box::new(best_circ_json)), - circ_candidates_csv: Some(Box::new(circ_candidates_csv)), - progress_log: Some(Box::new(progress_log)), - } - } -} +use self::log::TasoLogger; /// The TASO optimiser. /// @@ -116,7 +90,7 @@ where pub fn optimise_with_log( &self, circ: &Hugr, - log_config: LogConfig, + log_config: TasoLogger, timeout: Option, n_threads: NonZeroUsize, ) -> Hugr { @@ -126,37 +100,13 @@ where } } - /// Run the TASO optimiser on a circuit with default logging. - /// - /// The following files will be created: - /// - `final_circ.json`: the final optimised circuit, in TK1 JSON format, - /// - `best_circs.csv`: a log of the successive best candidate circuits, - /// - `taso-optimisation.log`: a log of the progress of the optimisation. - /// - /// If the creation of any of these files fails, an error is returned. - /// - /// A timeout (in seconds) can be provided. - pub fn optimise_with_default_log( - &self, - circ: &Hugr, - timeout: Option, - n_threads: NonZeroUsize, - ) -> io::Result { - let final_circ_json = fs::File::create("final_circ.json")?; - let circ_candidates_csv = fs::File::create("best_circs.csv")?; - let progress_log = fs::File::create("taso-optimisation.log")?; - let log_config = LogConfig::new(final_circ_json, circ_candidates_csv, progress_log); - Ok(self.optimise_with_log(circ, log_config, timeout, n_threads)) - } - - fn taso(&self, circ: &Hugr, mut log_config: LogConfig, timeout: Option) -> Hugr { + #[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))] + fn taso(&self, circ: &Hugr, mut logger: TasoLogger, timeout: Option) -> Hugr { let start_time = Instant::now(); - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - let mut best_circ = circ.clone(); let mut best_circ_cost = (self.cost)(circ); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + logger.log_best(best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); @@ -167,26 +117,19 @@ where pq.push(circ.clone()); let mut circ_cnt = 1; + let mut timeout_flag = false; while let Some(Entry { circ, cost, .. }) = pq.pop() { if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + logger.log_best(best_circ_cost); } let rewrites = self.rewriter.get_rewrites(&circ); for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { let new_circ_hash = new_circ.circuit_hash(); circ_cnt += 1; - if circ_cnt % 1000 == 0 { - log_progress( - log_config.progress_log.as_mut(), - circ_cnt, - Some(&pq), - &seen_hashes, - ) - .expect("Failed to write to progress log"); - } + logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); if seen_hashes.contains(&new_circ_hash) { continue; } @@ -201,22 +144,13 @@ where if let Some(timeout) = timeout { if start_time.elapsed().as_secs() > timeout { - println!("Timeout"); + timeout_flag = true; break; } } } - log_processing_end(circ_cnt, false); - - log_final( - &best_circ, - log_config.progress_log.as_mut(), - log_config.final_circ_json.as_mut(), - &self.cost, - ) - .expect("Failed to write to progress log and/or final circuit JSON"); - + logger.log_processing_end(circ_cnt, best_circ_cost, false, timeout_flag); best_circ } @@ -224,18 +158,17 @@ where /// /// This is the multi-threaded version of [`taso`]. See [`TasoOptimiser`] for /// more details. + #[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))] fn taso_multithreaded( &self, circ: &Hugr, - mut log_config: LogConfig, + mut logger: TasoLogger, timeout: Option, n_threads: NonZeroUsize, ) -> Hugr { let n_threads: usize = n_threads.get(); const PRIORITY_QUEUE_CAPACITY: usize = 10_000; - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - // multi-consumer priority channel for queuing circuits to be processed by the workers let (tx_work, rx_work) = HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); @@ -245,7 +178,7 @@ where let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); let mut best_circ_cost = (self.cost)(&best_circ); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + logger.log_best(best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); @@ -253,12 +186,13 @@ where // Each worker waits for circuits to scan for rewrites using all the // patterns and sends the results back to main. let joins: Vec<_> = (0..n_threads) - .map(|_| { - worker::spawn_pattern_matching_thread( + .map(|i| { + TasoWorker::spawn( rx_work.clone(), tx_result.clone(), self.rewriter.clone(), self.strategy.clone(), + Some(format!("taso-worker-{i}")), ) }) .collect(); @@ -283,7 +217,7 @@ where let mut jobs_completed = 0usize; // TODO: Report dropped jobs in the queue, so we can check for termination. - // Deadline for the optimization timeout + // Deadline for the optimisation timeout let timeout_event = match timeout { None => crossbeam_channel::never(), Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), @@ -291,38 +225,39 @@ where // Process worker results until we have seen all the circuits, or we run // out of time. + let mut timeout_flag = false; loop { select! { recv(rx_result) -> msg => { match msg { Ok(hashed_circs) => { - jobs_completed += 1; - for (circ_hash, circ) in &hashed_circs { - circ_cnt += 1; - if circ_cnt % 1000 == 0 { - // TODO: Add a minimum time between logs - log_progress::<_,u64,usize>(log_config.progress_log.as_mut(), circ_cnt, None, &seen_hashes) - .expect("Failed to write to progress log"); - } - if !seen_hashes.insert(*circ_hash) { - continue; + let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| { + jobs_completed += 1; + for (circ_hash, circ) in &hashed_circs { + circ_cnt += 1; + logger.log_progress(circ_cnt, None, seen_hashes.len()); + if seen_hashes.contains(circ_hash) { + continue; + } + seen_hashes.insert(*circ_hash); + + let cost = (self.cost)(circ); + + // Check if we got a new best circuit + if cost < best_circ_cost { + best_circ = circ.clone(); + best_circ_cost = cost; + logger.log_best(best_circ_cost); + } + jobs_sent += 1; } - - let cost = (self.cost)(circ); - - // Check if we got a new best circuit - if cost < best_circ_cost { - best_circ = circ.clone(); - best_circ_cost = cost; - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - } - jobs_sent += 1; - } - // Fill the workqueue with data from pq - if tx_work.send(hashed_circs).is_err() { + // Fill the workqueue with data from pq + tx_work.send(hashed_circs) + }); + if send_result.is_err() { eprintln!("All our workers panicked. Stopping optimisation."); break; - }; + } // If there is no more data to process, we are done. // @@ -338,25 +273,17 @@ where } } recv(timeout_event) -> _ => { - println!("Timeout"); + timeout_flag = true; break; } } } - log_processing_end(circ_cnt, true); + logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag); // Drop the channel so the threads know to stop. drop(tx_work); - let _ = joins; // joins.into_iter().for_each(|j| j.join().unwrap()); - - log_final( - &best_circ, - log_config.progress_log.as_mut(), - log_config.final_circ_json.as_mut(), - &self.cost, - ) - .expect("Failed to write to progress log and/or final circuit JSON"); + joins.into_iter().for_each(|j| j.join().unwrap()); best_circ } @@ -381,68 +308,3 @@ mod taso_default { } } } - -/// A helper struct for logging improvements in circuit size seen during the -/// TASO execution. -// -// TODO: Replace this fixed logging. Report back intermediate results. -#[derive(serde::Serialize, Clone, Debug)] -struct BestCircSer { - circ_len: usize, - time: String, -} - -impl BestCircSer { - fn new(circ_len: usize) -> Self { - let time = chrono::Local::now().to_rfc3339(); - Self { circ_len, time } - } -} - -fn log_best(cbest: usize, wtr: Option<&mut csv::Writer>) -> io::Result<()> { - let Some(wtr) = wtr else { - return Ok(()); - }; - println!("new best of size {}", cbest); - wtr.serialize(BestCircSer::new(cbest)).unwrap(); - wtr.flush() -} - -fn log_processing_end(circuit_count: usize, needs_joining: bool) { - println!("END"); - println!("Tried {circuit_count} circuits"); - if needs_joining { - println!("Joining"); - } -} - -fn log_progress( - wr: Option<&mut W>, - circ_cnt: usize, - pq: Option<&HugrPQ>, - seen_hashes: &FxHashSet, -) -> io::Result<()> { - if let Some(wr) = wr { - writeln!(wr, "{circ_cnt} circuits...")?; - if let Some(pq) = pq { - writeln!(wr, "Queue size: {} circuits", pq.len())?; - } - writeln!(wr, "Total seen: {} circuits", seen_hashes.len())?; - } - Ok(()) -} - -fn log_final( - best_circ: &Hugr, - log: Option<&mut W1>, - final_circ: Option<&mut W2>, - cost: impl Fn(&Hugr) -> usize, -) -> io::Result<()> { - if let Some(log) = log { - writeln!(log, "END RESULT: {}", cost(best_circ))?; - } - if let Some(circ_writer) = final_circ { - save_tk1_json_writer(best_circ, circ_writer).unwrap(); - } - Ok(()) -} diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs index c3e157cf..1ec0d2e4 100644 --- a/src/optimiser/taso/hugr_pchannel.rs +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -45,6 +45,7 @@ where let in_channel = in_channel_orig.clone(); let out_channel = out_channel_orig.clone(); let _ = builder + .name("priority-channel".into()) .spawn(move || { // The priority queue, local to this thread. let mut pq: HugrPQ = diff --git a/src/optimiser/taso/log.rs b/src/optimiser/taso/log.rs new file mode 100644 index 00000000..39969ee6 --- /dev/null +++ b/src/optimiser/taso/log.rs @@ -0,0 +1,111 @@ +//! Logging utilities for the TASO optimiser. + +use std::io; + +/// Logging configuration for the TASO optimiser. +#[derive(Default)] +pub struct TasoLogger<'w> { + circ_candidates_csv: Option>>, +} + +/// The logging target for general events. +pub const LOG_TARGET: &str = "taso::log"; +/// The logging target for progress events. More verbose than the general log. +pub const PROGRESS_TARGET: &str = "taso::progress"; +/// The logging target for function spans. +pub const METRICS_TARGET: &str = "taso::metrics"; + +impl<'w> TasoLogger<'w> { + /// Create a new logging configuration. + /// + /// Two writer objects must be provided: + /// - best_progress_csv_writer: for a log of the successive best candidate + /// circuits, + /// + /// Regular events are logged with [`tracing`], with targets [`LOG_TARGET`] + /// or [`PROGRESS_TARGET`]. + /// + /// [`log`]: + pub fn new(best_progress_csv_writer: impl io::Write + 'w) -> Self { + let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); + Self { + circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)), + } + } + + /// Log a new best candidate + #[inline] + pub fn log_best(&mut self, best_cost: usize) { + self.log(format!("new best of size {}", best_cost)); + if let Some(csv_writer) = self.circ_candidates_csv.as_mut() { + csv_writer.serialize(BestCircSer::new(best_cost)).unwrap(); + csv_writer.flush().unwrap(); + }; + } + + /// Log the final optimised circuit + #[inline] + pub fn log_processing_end( + &self, + circuit_count: usize, + best_cost: usize, + needs_joining: bool, + timeout: bool, + ) { + if timeout { + self.log("Timeout"); + } + self.log("Optimisation finished"); + self.log(format!("Tried {circuit_count} circuits")); + self.log(format!("END RESULT: {}", best_cost)); + if needs_joining { + self.log("Joining worker threads"); + } + } + + /// Log the progress of the optimisation. + #[inline(always)] + pub fn log_progress( + &mut self, + circ_cnt: usize, + workqueue_len: Option, + seen_hashes: usize, + ) { + if circ_cnt % 1000 == 0 { + self.progress(format!("{circ_cnt} circuits...")); + if let Some(workqueue_len) = workqueue_len { + self.progress(format!("Queue size: {workqueue_len} circuits")); + } + self.progress(format!("Total seen: {} circuits", seen_hashes)); + } + } + + /// Internal function to log general events, normally printed to stdout. + #[inline] + fn log(&self, msg: impl AsRef) { + tracing::info!(target: LOG_TARGET, "{}", msg.as_ref()); + } + + /// Internal function to log information on the progress of the optimization. + #[inline] + fn progress(&self, msg: impl AsRef) { + tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref()); + } +} + +/// A helper struct for logging improvements in circuit size seen during the +/// TASO execution. +// +// TODO: Replace this fixed logging. Report back intermediate results. +#[derive(serde::Serialize, Clone, Debug)] +struct BestCircSer { + circ_len: usize, + time: String, +} + +impl BestCircSer { + fn new(circ_len: usize) -> Self { + let time = chrono::Local::now().to_rfc3339(); + Self { circ_len, time } + } +} diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs index 4472031e..0d8667c8 100644 --- a/src/optimiser/taso/worker.rs +++ b/src/optimiser/taso/worker.rs @@ -8,24 +8,60 @@ use crate::circuit::CircuitHash; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -pub fn spawn_pattern_matching_thread( - rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, - tx_result: crossbeam_channel::Sender>, - rewriter: impl Rewriter + Send + 'static, - strategy: impl RewriteStrategy + Send + 'static, -) -> JoinHandle<()> { - thread::spawn(move || { - // Process work until the main thread closes the channel send or receive - // channel. +/// A worker that processes circuits for the TASO optimiser. +pub struct TasoWorker { + _phantom: std::marker::PhantomData<(R, S)>, +} + +impl TasoWorker +where + R: Rewriter + Send + 'static, + S: RewriteStrategy + Send + 'static, +{ + /// Spawn a new worker thread. + pub fn spawn( + rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, + tx_result: crossbeam_channel::Sender>, + rewriter: R, + strategy: S, + worker_name: Option, + ) -> JoinHandle<()> { + let mut builder = thread::Builder::new(); + if let Some(name) = worker_name { + builder = builder.name(name); + }; + builder + .spawn(move || Self::worker_loop(rx_work, tx_result, rewriter, strategy)) + .unwrap() + } + + /// Main loop of the worker. + /// + /// Processes work until the main thread closes the channel send or receive + /// channel. + #[tracing::instrument(target = "taso::metrics", skip_all)] + fn worker_loop( + rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, + tx_result: crossbeam_channel::Sender>, + rewriter: R, + strategy: S, + ) { while let Ok((_hash, circ)) = rx_work.recv() { - let rewrites = rewriter.get_rewrites(&circ); - let circs = strategy.apply_rewrites(rewrites, &circ); - let hashed_circs = circs.into_iter().map(|c| (c.circuit_hash(), c)).collect(); - let send = tx_result.send(hashed_circs); + let hashed_circs = Self::process_circ(circ, &rewriter, &strategy); + let send = tracing::trace_span!(target: "taso::metrics", "TasoWorker::send_result") + .in_scope(|| tx_result.send(hashed_circs)); if send.is_err() { // The main thread closed the send channel, we can stop. break; } } - }) + } + + /// Process a circuit. + #[tracing::instrument(target = "taso::metrics", skip_all)] + fn process_circ(circ: Hugr, rewriter: &R, strategy: &S) -> Vec<(u64, Hugr)> { + let rewrites = rewriter.get_rewrites(&circ); + let circs = strategy.apply_rewrites(rewrites, &circ); + circs.into_iter().map(|c| (c.circuit_hash(), c)).collect() + } } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 645b5c16..29d5a934 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -45,6 +45,7 @@ pub trait RewriteStrategy { pub struct GreedyRewriteStrategy; impl RewriteStrategy for GreedyRewriteStrategy { + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, rewrites: impl IntoIterator, @@ -99,6 +100,7 @@ impl Default for ExhaustiveRewriteStrategy { } impl RewriteStrategy for ExhaustiveRewriteStrategy { + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, rewrites: impl IntoIterator, diff --git a/taso-optimiser/Cargo.toml b/taso-optimiser/Cargo.toml index 787f5654..d3106737 100644 --- a/taso-optimiser/Cargo.toml +++ b/taso-optimiser/Cargo.toml @@ -12,8 +12,11 @@ tket2 = { path = "../", features = ["portmatching"] } quantinuum-hugr = { workspace = true } itertools = { workspace = true } tket-json-rs = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = "0.3.17" +tracing-appender = "0.2.2" peak_alloc = { version = "0.2.0", optional = true } [features] default = ["peak_alloc"] -peak_alloc = ["dep:peak_alloc"] \ No newline at end of file +peak_alloc = ["dep:peak_alloc"] diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 4c277f8e..95368318 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -1,9 +1,16 @@ +mod tracing; + +use crate::tracing::Tracer; + +use std::io::BufWriter; use std::num::NonZeroUsize; +use std::path::PathBuf; use std::process::exit; -use std::{fs, io, path::Path}; +use std::{fs, path::Path}; use clap::Parser; use hugr::Hugr; +use tket2::optimiser::taso::log::TasoLogger; use tket2::{ json::{load_tk1_json_file, TKETDecode}, optimiser::TasoOptimiser, @@ -31,7 +38,7 @@ struct CmdLineArgs { value_name = "FILE", help = "Input. A quantum circuit in TK1 JSON format." )] - input: String, + input: PathBuf, /// Output circuit file #[arg( short, @@ -40,7 +47,7 @@ struct CmdLineArgs { value_name = "FILE", help = "Output. A quantum circuit in TK1 JSON format." )] - output: String, + output: PathBuf, /// ECC file #[arg( short, @@ -48,7 +55,16 @@ struct CmdLineArgs { value_name = "ECC_FILE", help = "Sets the ECC file to use. It is a JSON file of Quartz-generated ECCs." )] - eccs: String, + eccs: PathBuf, + /// Log output file + #[arg( + short, + long, + default_value = "taso-optimisation.log", + value_name = "LOGFILE", + help = "Logfile to to output the progress of the optimisation." + )] + logfile: Option, /// Timeout in seconds (default=no timeout) #[arg( short, @@ -62,27 +78,37 @@ struct CmdLineArgs { short = 'j', long, value_name = "N_THREADS", - help = "The number of threads to use. By default, the number of threads is equal to the number of logical cores." + help = "The number of threads to use. By default, use a single thread." )] n_threads: Option, } fn save_tk1_json_file(path: impl AsRef, circ: &Hugr) -> Result<(), std::io::Error> { let file = fs::File::create(path)?; - let writer = io::BufWriter::new(file); + let writer = BufWriter::new(file); let serial_circ = SerialCircuit::encode(circ).unwrap(); serde_json::to_writer_pretty(writer, &serial_circ)?; Ok(()) } -fn main() { +fn main() -> Result<(), Box> { let opts = CmdLineArgs::parse(); + // Setup tracing subscribers for stdout and file logging. + // + // We need to keep the object around to keep the logging active. + let _tracer = Tracer::setup_tracing(opts.logfile); + let input_path = Path::new(&opts.input); let output_path = Path::new(&opts.output); let ecc_path = Path::new(&opts.eccs); - let circ = load_tk1_json_file(input_path).unwrap(); + // TODO: Remove this from the Logger, and use tracing events instead. + let circ_candidates_csv = fs::File::create("best_circs.csv")?; + + let taso_logger = TasoLogger::new(circ_candidates_csv); + + let circ = load_tk1_json_file(input_path)?; println!("Compiling rewriter..."); let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else { @@ -101,15 +127,14 @@ fn main() { println!("Using {n_threads} threads"); println!("Optimising..."); - let opt_circ = optimiser - .optimise_with_default_log(&circ, opts.timeout, n_threads) - .unwrap(); + let opt_circ = optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads); println!("Saving result"); - save_tk1_json_file(output_path, &opt_circ).unwrap(); + save_tk1_json_file(output_path, &opt_circ)?; #[cfg(feature = "peak_alloc")] println!("Peak memory usage: {} GB", PEAK_ALLOC.peak_usage_as_gb()); - println!("Done.") + println!("Done."); + Ok(()) } diff --git a/taso-optimiser/src/tracing.rs b/taso-optimiser/src/tracing.rs new file mode 100644 index 00000000..0e90f0a5 --- /dev/null +++ b/taso-optimiser/src/tracing.rs @@ -0,0 +1,73 @@ +//! Setup routines for tracing and logging of the optimisation process. +use std::fs::File; +use std::io::BufWriter; +use std::path::PathBuf; + +use tket2::optimiser::taso::log::{LOG_TARGET, METRICS_TARGET, PROGRESS_TARGET}; + +use tracing::{Metadata, Subscriber}; +use tracing_appender::non_blocking; +use tracing_subscriber::filter::filter_fn; +use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::Layer; + +fn log_filter(metadata: &Metadata<'_>) -> bool { + metadata.target().starts_with(LOG_TARGET) +} + +fn verbose_filter(metadata: &Metadata<'_>) -> bool { + metadata.target().starts_with(LOG_TARGET) || metadata.target().starts_with(PROGRESS_TARGET) +} + +#[allow(unused)] +fn metrics_filter(metadata: &Metadata<'_>) -> bool { + metadata.target().starts_with(METRICS_TARGET) +} + +#[derive(Debug, Default)] +pub struct Tracer { + pub logfile: Option, +} + +impl Tracer { + /// Setup tracing subscribers for stdout and file logging. + pub fn setup_tracing(logfile: Option) -> Self { + let mut tracer = Self::default(); + tracing_subscriber::registry() + .with(tracer.stdout_layer()) + .with(logfile.map(|f| tracer.logfile_layer(f))) + .init(); + tracer + } + + /// Initialize a file logger handle and non-blocking worker. + fn init_writer(&self, file: PathBuf) -> (non_blocking::NonBlocking, non_blocking::WorkerGuard) { + let writer = BufWriter::new(File::create(file).unwrap()); + non_blocking(writer) + } + + /// Clean log with the most important events. + fn stdout_layer(&mut self) -> impl Layer + where + S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>, + { + tracing_subscriber::fmt::layer() + .without_time() + .with_target(false) + .with_level(false) + .with_filter(filter_fn(log_filter)) + } + + fn logfile_layer(&mut self, logfile: PathBuf) -> impl Layer + where + S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>, + { + let (non_blocking, guard) = self.init_writer(logfile); + self.logfile = Some(guard); + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_writer(non_blocking) + .with_filter(filter_fn(verbose_filter)) + } +}