Skip to content

Commit

Permalink
add sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed Mar 5, 2024
1 parent 4c458ca commit b3c5df7
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 11 deletions.
60 changes: 60 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ tracing-subscriber = "0.3.17"
lazy_static = "1.4.0"
serde = "1.0.189"
serde_json = "1.0.107"
rand = "0.8.5"

[package.metadata.rust-analyzer]
rustc_private = true
51 changes: 40 additions & 11 deletions src/bin/nopcrat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{

use clap::Parser;
use nopcrat::*;
use rand::prelude::*;

#[derive(Parser, Debug)]
struct Args {
Expand All @@ -27,6 +28,12 @@ struct Args {
#[arg(short, long)]
size: bool,
#[arg(long)]
sample_negative: bool,
#[arg(long)]
sample_may: bool,
#[arg(long)]
sample_must: bool,
#[arg(long)]
time: bool,
#[arg(long)]
function_times: Option<usize>,
Expand Down Expand Up @@ -83,7 +90,7 @@ fn main() {
print_functions: args.print_function.into_iter().collect(),
function_times: args.function_times,
};
let analysis_result = if let Some(dump_file) = args.use_analysis_result {
let analysis_result = if let Some(dump_file) = &args.use_analysis_result {
let dump_file = File::open(dump_file).unwrap();
serde_json::from_reader(dump_file).unwrap()
} else {
Expand All @@ -99,16 +106,38 @@ fn main() {
}
}

let fns = analysis_result.len();
let musts = analysis_result
.values()
.map(|v| v.iter().filter(|p| p.must).count())
.sum::<usize>();
let mays = analysis_result
.values()
.map(|v| v.iter().filter(|p| !p.must).count())
.sum::<usize>();
println!("{} {} {}", fns, musts, mays);
if args.sample_negative {
let mut fns = sampling::sample_from_path(path, &analysis_result);
fns.shuffle(&mut thread_rng());
for f in fns.iter().take(10) {
println!("{:?}", f);
}
return;
}
if args.sample_may || args.sample_must {
let mut params: Vec<_> = analysis_result
.iter()
.filter(|(_, params)| params.iter().any(|p| p.must == args.sample_must))
.collect();
params.shuffle(&mut thread_rng());
for (f, ps) in params.iter().take(10) {
println!("{}\n{:?}", f, ps);
}
return;
}

if args.use_analysis_result.is_none() {
let fns = analysis_result.len();
let musts = analysis_result
.values()
.map(|v| v.iter().filter(|p| p.must).count())
.sum::<usize>();
let mays = analysis_result
.values()
.map(|v| v.iter().filter(|p| !p.must).count())
.sum::<usize>();
println!("{} {} {}", fns, musts, mays);
}

if let Some(dump_file) = args.dump_analysis_result {
let dump_file = File::create(dump_file).unwrap();
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ extern crate rustc_type_ir;
pub mod ai;
pub mod compile_util;
pub mod graph;
pub mod sampling;
pub mod size;
pub mod transform;
93 changes: 93 additions & 0 deletions src/sampling.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use std::path::Path;

use rustc_ast::Mutability;
use rustc_const_eval::interpret::ConstValue;
use rustc_middle::{
mir::{visit::Visitor, Body, ConstantKind},
ty::{TyCtxt, TyKind, TypeAndMut},
};
use rustc_session::config::Input;

use crate::{ai::analysis::AnalysisResult, compile_util};

pub fn sample_from_path(path: &Path, res: &AnalysisResult) -> Vec<String> {
sample_from_input(compile_util::path_to_input(path), res)
}

pub fn sample_from_code(code: &str, res: &AnalysisResult) -> Vec<String> {
sample_from_input(compile_util::str_to_input(code), res)
}

fn sample_from_input(input: Input, res: &AnalysisResult) -> Vec<String> {
let config = compile_util::make_config(input);
compile_util::run_compiler(config, |tcx| {
let hir = tcx.hir();
let mut fns = vec![];
for id in hir.items() {
let item = hir.item(id);
if item.ident.name.to_ident_string() == "main" {
continue;
}
if !matches!(item.kind, rustc_hir::ItemKind::Fn(_, _, _)) {
continue;
}
let def_id = id.owner_id.to_def_id();
let body = tcx.optimized_mir(def_id);
let name = tcx.def_path_str(def_id);
if !res.contains_key(&name)
&& body
.args_iter()
.any(|arg| match body.local_decls[arg].ty.kind() {
TyKind::RawPtr(TypeAndMut {
ty,
mutbl: Mutability::Mut,
}) => !ty.is_primitive() && !ty.is_c_void(tcx) && !ty.is_any_ptr(),
_ => false,
})
// && has_call(body, tcx)
{
fns.push(name);
}
}
fns
})
.unwrap()
}

#[allow(unused)]
fn has_call<'tcx>(body: &Body<'tcx>, tcx: TyCtxt<'tcx>) -> bool {
let mut visitor = CallVisitor { tcx, b: false };
visitor.visit_body(body);
visitor.b
}

struct CallVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
b: bool,
}

impl<'tcx> Visitor<'tcx> for CallVisitor<'tcx> {
fn visit_terminator(
&mut self,
terminator: &rustc_middle::mir::Terminator<'tcx>,
location: rustc_middle::mir::Location,
) {
if let rustc_middle::mir::TerminatorKind::Call { func, .. } = &terminator.kind {
if let Some(constant) = func.constant() {
if let ConstantKind::Val(ConstValue::ZeroSized, ty) = constant.literal {
if let TyKind::FnDef(def_id, _) = ty.kind() {
let name = self.tcx.def_path(*def_id).to_string_no_crate_verbose();
if name.contains("{extern#")
&& (name.contains("cpy")
|| name.contains("set")
|| name.contains("move"))
{
self.b = true;
}
}
}
}
}
self.super_terminator(terminator, location);
}
}

0 comments on commit b3c5df7

Please sign in to comment.