diff --git a/support/code-coverage/src/lib.rs b/support/code-coverage/src/lib.rs index c69c3001a..979e28b08 100644 --- a/support/code-coverage/src/lib.rs +++ b/support/code-coverage/src/lib.rs @@ -79,6 +79,7 @@ pub fn analyze_files(rust_files: &[PathBuf], workspace_root: &Path) -> Vec Vec>(); custom_println!("[code-coverage]", green, "found {} tests", tests.len()); + custom_println!("[code-coverage]", green, "found {} benchmarks", benchmarks.len()); custom_println!( "[code-coverage]", @@ -113,6 +115,17 @@ pub fn analyze_files(rust_files: &[PathBuf], workspace_root: &Path) -> Vec>(); coverage.par_sort_by_key(|(_, v)| *v); @@ -207,7 +220,7 @@ pub fn find_tests(rust_files: &[PathBuf]) -> Vec { .into_iter() .map(|f| { let mut method_calls = HashSet::new(); - let mut visitor = CallVisitor { + let mut visitor = MethodCallVisitor { method_calls: &mut method_calls, }; visitor.visit_item_fn(&f); @@ -225,11 +238,65 @@ pub fn find_tests(rust_files: &[PathBuf]) -> Vec { }) } +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct BenchmarkInfo { + pub calls: HashSet, +} + +/// Finds all benchmarks in the given set of rust files, using a parallel map-reduce +pub fn find_benchmarks(rust_files: &[PathBuf]) -> Vec { + rust_files + .par_iter() + .map(|path| { + let Ok(content) = fs::read_to_string(path) else { + return Vec::new(); + }; + let Ok(file) = syn::parse_file(&content) else { + return Vec::new(); + }; + let mut visitor = BenchmarkVisitor { benchmarks: Vec::new() }; + visitor.visit_file(&file); + visitor + .benchmarks + .into_iter() + .map(|f| { + let mut calls = HashSet::new(); + let mut visitor = CallVisitor { + calls: &mut calls, + }; + visitor.visit_item_fn(&f); + BenchmarkInfo { + calls, + } + }) + .collect() + }) + .reduce(Vec::new, |mut acc, mut infos| { + acc.append(&mut infos); + acc + }) +} + pub struct CallVisitor<'a> { - pub method_calls: &'a mut HashSet, + pub calls: &'a mut HashSet, } impl<'ast> Visit<'ast> for CallVisitor<'_> { + fn visit_expr_call(&mut self, i: &'ast syn::ExprCall) { + if let syn::Expr::Path(expr) = &*i.func { + if let Some(seg) = expr.path.segments.last() { + self.calls.insert(seg.ident.to_string()); + } + } + syn::visit::visit_expr_call(self, i); + } +} + +pub struct MethodCallVisitor<'a> { + pub method_calls: &'a mut HashSet, +} + +impl<'ast> Visit<'ast> for MethodCallVisitor<'_> { fn visit_expr_method_call(&mut self, i: &'ast syn::ExprMethodCall) { self.method_calls.insert(i.method.to_string()); syn::visit::visit_expr_method_call(self, i); @@ -260,6 +327,24 @@ impl<'ast> Visit<'ast> for TestVisitor { } } +pub struct BenchmarkVisitor { + pub benchmarks: Vec, +} + +impl<'ast> Visit<'ast> for BenchmarkVisitor { + fn visit_item_fn(&mut self, item_fn: &'ast ItemFn) { + if item_fn.attrs.iter().any(|attr| { + let Some(seg) = attr.path().segments.last() else { + return false; + }; + seg.ident == "benchmark" + }) { + self.benchmarks.push(item_fn.clone()); + } + syn::visit::visit_item_fn(self, item_fn); + } +} + /// Tries to parse a pallet from a module pub fn try_parse_pallet(item_mod: &ItemMod, file_path: &Path, root_path: &Path) -> Option { simulate_manifest_dir("pallets/subtensor", || -> Option {