diff --git a/kani-compiler/src/kani_middle/provide.rs b/kani-compiler/src/kani_middle/provide.rs index bdca51f43129..c07cc1e477fe 100644 --- a/kani-compiler/src/kani_middle/provide.rs +++ b/kani-compiler/src/kani_middle/provide.rs @@ -55,7 +55,9 @@ fn run_kani_mir_passes<'tcx>( body: &'tcx Body<'tcx>, ) -> &'tcx Body<'tcx> { tracing::debug!(?def_id, "Run Kani transformation passes"); - stubbing::transform(tcx, def_id, body) + let mut transformed_body = stubbing::transform(tcx, def_id, body); + stubbing::transform_foreign_functions(tcx, &mut transformed_body); + tcx.arena.alloc(transformed_body) } /// Runs a reachability analysis before running the default diff --git a/kani-compiler/src/kani_middle/stubbing/transform.rs b/kani-compiler/src/kani_middle/stubbing/transform.rs index 2a164623ba24..c8c508ad94ac 100644 --- a/kani-compiler/src/kani_middle/stubbing/transform.rs +++ b/kani-compiler/src/kani_middle/stubbing/transform.rs @@ -12,7 +12,12 @@ use lazy_static::lazy_static; use regex::Regex; use rustc_data_structures::fingerprint::Fingerprint; use rustc_hir::{def_id::DefId, definitions::DefPathHash}; -use rustc_middle::{mir::Body, ty::TyCtxt}; +use rustc_index::IndexVec; +use rustc_middle::mir::{ + interpret::ConstValue, visit::MutVisitor, Body, ConstantKind, Local, LocalDecl, Location, + Operand, +}; +use rustc_middle::ty::{self, TyCtxt}; /// Returns the `DefId` of the stub for the function/method identified by the /// parameter `def_id`, and `None` if the function/method is not stubbed. @@ -23,18 +28,58 @@ pub fn get_stub(tcx: TyCtxt, def_id: DefId) -> Option { /// Returns the new body of a function/method if it has been stubbed out; /// otherwise, returns the old body. -pub fn transform<'tcx>( - tcx: TyCtxt<'tcx>, - def_id: DefId, - old_body: &'tcx Body<'tcx>, -) -> &'tcx Body<'tcx> { +pub fn transform<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, old_body: &'tcx Body<'tcx>) -> Body<'tcx> { if let Some(replacement) = get_stub(tcx, def_id) { let new_body = tcx.optimized_mir(replacement).clone(); if check_compatibility(tcx, def_id, old_body, replacement, &new_body) { - return tcx.arena.alloc(new_body); + return new_body; + } + } + old_body.clone() +} + +/// Traverse `body` searching for calls to foreing functions and, whevever there is +/// a stub available, replace the call to the foreign function with a call +/// to its correspondent stub. This happens as a separate step because there is no +/// body available to foreign functions at this stage. +pub fn transform_foreign_functions<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + if let Some(stub_map) = get_stub_mapping(tcx) { + let mut visitor = + ForeignFunctionTransformer { tcx, local_decls: body.clone().local_decls, stub_map }; + visitor.visit_body(body); + } +} + +struct ForeignFunctionTransformer<'tcx> { + /// The compiler context. + tcx: TyCtxt<'tcx>, + /// Local declarations of the callee function. Kani searches here for foreign functions. + local_decls: IndexVec>, + /// Map of functions/methods to their correspondent stubs. + stub_map: HashMap, +} + +impl<'tcx> MutVisitor<'tcx> for ForeignFunctionTransformer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _location: Location) { + let func_ty = operand.ty(&self.local_decls, self.tcx); + if let ty::FnDef(reachable_function, arguments) = *func_ty.kind() { + if self.tcx.is_foreign_item(reachable_function) { + if let Some(stub) = self.stub_map.get(&reachable_function) { + let Operand::Constant(function_definition) = operand else { + return; + }; + function_definition.literal = ConstantKind::from_value( + ConstValue::ZeroSized, + self.tcx.type_of(stub).instantiate(self.tcx, arguments), + ); + } + } } } - old_body } /// Checks whether the stub is compatible with the original function/method: do diff --git a/tests/.gitignore b/tests/.gitignore index 55cf0ad46a3b..7f948a83337c 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1,11 +1,11 @@ -*.goto +# Temporary files and folders *.json - -# Temporary folders -rmet*/ kani_concrete_playback +rmet*/ +target/ # Binary artifacts +*.goto *out smoke check_tests diff --git a/tests/kani/Stubbing/foreign_functions.rs b/tests/kani/Stubbing/foreign_functions.rs new file mode 100644 index 000000000000..8c2ff0812905 --- /dev/null +++ b/tests/kani/Stubbing/foreign_functions.rs @@ -0,0 +1,76 @@ +// Copyright Kani Contributors +// SPDX-License-Identifier: Apache-2.0 OR MIT +// +// kani-flags: --enable-unstable --enable-stubbing +// +//! Check support for stubbing out foreign functions. + +#![feature(rustc_private)] +extern crate libc; + +use libc::c_char; +use libc::c_int; +use libc::c_longlong; +use libc::size_t; + +#[allow(dead_code)] // Avoid warning when using stubs. +#[allow(unused_variables)] +mod stubs { + use super::*; + + pub unsafe extern "C" fn strlen(cs: *const c_char) -> size_t { + 4 + } + + pub unsafe extern "C" fn sysconf(_input: c_int) -> c_longlong { + 10 + } +} + +fn dig_deeper(input: c_int) { + unsafe { + type FunctionPointerType = unsafe extern "C" fn(c_int) -> c_longlong; + let ptr: FunctionPointerType = libc::sysconf; + assert_eq!(ptr(input) as usize, 10); + } +} + +fn deeper_call() { + dig_deeper(libc::_SC_PAGESIZE) +} + +fn function_pointer_call(function_pointer: unsafe extern "C" fn(c_int) -> c_longlong) { + assert_eq!(unsafe { function_pointer(libc::_SC_PAGESIZE) } as usize, 10); +} + +#[kani::proof] +#[kani::stub(libc::strlen, stubs::strlen)] +fn standard() { + let str: Box = Box::new(4); + let str_ptr: *const i8 = &*str; + assert_eq!(unsafe { libc::strlen(str_ptr) }, 4); +} + +#[kani::proof] +#[kani::stub(libc::strlen, stubs::strlen)] +fn function_pointer_standard() { + let str: Box = Box::new(4); + let str_ptr: *const i8 = &*str; + let new_ptr = libc::strlen; + assert_eq!(unsafe { new_ptr(str_ptr) }, 4); +} + +#[kani::proof] +#[kani::stub(libc::sysconf, stubs::sysconf)] +fn function_pointer_with_layers() { + deeper_call(); +} + +#[kani::proof] +#[kani::stub(libc::sysconf, stubs::sysconf)] +fn function_pointer_as_parameter() { + type FunctionPointerType = unsafe extern "C" fn(c_int) -> c_longlong; + let function_pointer: FunctionPointerType = libc::sysconf; + function_pointer_call(function_pointer); + function_pointer_call(libc::sysconf); +}