diff --git a/src/run.rs b/src/run.rs index 5dd80e4b..fbd37b30 100644 --- a/src/run.rs +++ b/src/run.rs @@ -151,13 +151,42 @@ pub struct Runner, IterData = ()> { #[allow(clippy::type_complexity)] pub hooks: Vec Result<(), String>>>, - // limits + limits: RunnerLimits, + scheduler: Box>, +} + +/// Describes the limits that would stop a [`Runner`]. +#[derive(Debug)] +pub struct RunnerLimits { iter_limit: usize, node_limit: usize, time_limit: Duration, - start_time: Option, - scheduler: Box>, +} + +impl RunnerLimits { + /// Check if the [`Runner`] should stop based on the limits. + pub fn check_limits(&self, iteration: usize, egraph: &EGraph) -> RunnerResult<()> + where + L: Language, + N: Analysis, + { + let elapsed = self.start_time.unwrap().elapsed(); + if elapsed > self.time_limit { + return Err(StopReason::TimeLimit(elapsed.as_secs_f64())); + } + + let size = egraph.total_size(); + if size > self.node_limit { + return Err(StopReason::NodeLimit(size)); + } + + if iteration >= self.iter_limit { + return Err(StopReason::IterationLimit(iteration)); + } + + Ok(()) + } } impl Default for Runner @@ -184,10 +213,7 @@ where roots, stop_reason, hooks, - iter_limit, - node_limit, - time_limit, - start_time, + limits, scheduler: _, } = self; @@ -197,10 +223,7 @@ where .field("roots", roots) .field("stop_reason", stop_reason) .field("hooks", &vec![format_args!(""); hooks.len()]) - .field("iter_limit", iter_limit) - .field("node_limit", node_limit) - .field("time_limit", time_limit) - .field("start_time", start_time) + .field("limits", limits) .field("scheduler", &format_args!("")) .finish() } @@ -303,7 +326,8 @@ pub struct Iteration { pub stop_reason: Option, } -type RunnerResult = std::result::Result; +/// Type alias for the result of a [`Runner`]. +pub type RunnerResult = std::result::Result; impl Runner where @@ -314,34 +338,37 @@ where /// Create a new `Runner` with the given analysis and default parameters. pub fn new(analysis: N) -> Self { Self { - iter_limit: 30, - node_limit: 10_000, - time_limit: Duration::from_secs(5), - + limits: RunnerLimits { + iter_limit: 30, + node_limit: 10_000, + time_limit: Duration::from_secs(5), + start_time: None, + }, egraph: EGraph::new(analysis), roots: vec![], iterations: vec![], stop_reason: None, hooks: vec![], - - start_time: None, scheduler: Box::new(BackoffScheduler::default()), } } /// Sets the iteration limit. Default: 30 - pub fn with_iter_limit(self, iter_limit: usize) -> Self { - Self { iter_limit, ..self } + pub fn with_iter_limit(mut self, iter_limit: usize) -> Self { + self.limits.iter_limit = iter_limit; + self } /// Sets the egraph size limit (in enodes). Default: 10,000 - pub fn with_node_limit(self, node_limit: usize) -> Self { - Self { node_limit, ..self } + pub fn with_node_limit(mut self, node_limit: usize) -> Self { + self.limits.node_limit = node_limit; + self } /// Sets the runner time limit. Default: 5 seconds - pub fn with_time_limit(self, time_limit: Duration) -> Self { - Self { time_limit, ..self } + pub fn with_time_limit(mut self, time_limit: Duration) -> Self { + self.limits.time_limit = time_limit; + self } /// Add a hook to instrument or modify the behavior of a [`Runner`]. @@ -525,11 +552,15 @@ where let mut matches = Vec::new(); let mut applied = IndexMap::default(); result = result.and_then(|_| { - rules.iter().try_for_each(|rw| { - let ms = self.scheduler.search_rewrite(i, &self.egraph, rw); - matches.push(ms); - self.check_limits() - }) + matches = self + .scheduler + .search_rewrites(i, &self.egraph, rules, &self.limits)?; + Ok(()) + // rules.iter().try_for_each(|rw| { + // let ms = self.scheduler.search_rewrite(i, &self.egraph, rw); + // matches.push(ms); + // self.check_limits() + // }) }); let search_time = start_time.elapsed().as_secs_f64(); @@ -602,25 +633,12 @@ where } fn try_start(&mut self) { - self.start_time.get_or_insert_with(Instant::now); + self.limits.start_time.get_or_insert_with(Instant::now); } fn check_limits(&self) -> RunnerResult<()> { - let elapsed = self.start_time.unwrap().elapsed(); - if elapsed > self.time_limit { - return Err(StopReason::TimeLimit(elapsed.as_secs_f64())); - } - - let size = self.egraph.total_size(); - if size > self.node_limit { - return Err(StopReason::NodeLimit(size)); - } - - if self.iterations.len() >= self.iter_limit { - return Err(StopReason::IterationLimit(self.iterations.len())); - } - - Ok(()) + self.limits + .check_limits(self.iterations.len(), &self.egraph) } } @@ -678,6 +696,57 @@ where rewrite.search(egraph) } + /// A hook allowing you to customize rewrite searching behavior + /// across rewrites. + /// + /// Default implementation calls + /// [`Self::search_rewrite`] for each rewrite, + /// and checks [`RunnerLimits::check_limits`] after each. + /// + /// Returning an error will stop the runner. + /// + /// You might use this to implement parallel rule application: + /// ``` + /// # use egg::*; + /// pub struct ParallelRewriteScheduler; + /// impl RewriteScheduler for ParallelRewriteScheduler { + /// fn search_rewrites<'a>( + /// &mut self, + /// iteration: usize, + /// egraph: &EGraph, + /// rewrites: &[&'a Rewrite], + /// _limits: &RunnerLimits, + /// ) -> RunnerResult>>> { + /// // this implementation just ignores the limits + /// // fake `par_map` to enforce Send + Sync, in real life use rayon + /// fn par_map(slice: &[T], f: F) -> Vec + /// where + /// T: Send + Sync, + /// F: Fn(&T) -> T2 + Send + Sync, + /// T2: Send + Sync, + /// { + /// slice.iter().map(f).collect() + /// } + /// Ok(par_map(rewrites, |rw| rw.search(egraph))) + /// } + /// } + /// ``` + fn search_rewrites<'a>( + &mut self, + iteration: usize, + egraph: &EGraph, + rewrites: &[&'a Rewrite], + limits: &RunnerLimits, + ) -> RunnerResult>>> { + let mut matches = Vec::new(); + for rw in rewrites { + let ms = self.search_rewrite(iteration, egraph, rw); + matches.push(ms); + limits.check_limits(iteration, egraph)?; + } + Ok(matches) + } + /// A hook allowing you to customize rewrite application behavior. /// Useful to implement rule management. ///