diff --git a/utils/wasm-gen/src/generator.rs b/utils/wasm-gen/src/generator.rs index ffa44d21a90..0f38d92641c 100644 --- a/utils/wasm-gen/src/generator.rs +++ b/utils/wasm-gen/src/generator.rs @@ -141,7 +141,7 @@ impl<'a, 'b> GearWasmGenerator<'a, 'b> { Ok(if config.remove_recursions { log::trace!("Removing recursions"); - utils::remove_recursion(module) + utils::instrument_recursion(module) } else { log::trace!("Leaving recursions"); module diff --git a/utils/wasm-gen/src/tests.rs b/utils/wasm-gen/src/tests.rs index b37e007401f..02c0f66ccfc 100644 --- a/utils/wasm-gen/src/tests.rs +++ b/utils/wasm-gen/src/tests.rs @@ -51,7 +51,6 @@ const UNSTRUCTURED_SIZE: usize = 1_000_000; fn instrument_recursions() { let wat1 = r#" (module - (func $import0 (import "env" "gr_leave")) (memory $memory0 (import "env" "memory") 16) (export "handle" (func $handle)) (func $handle diff --git a/utils/wasm-gen/src/utils.rs b/utils/wasm-gen/src/utils.rs index dae90ae572e..1afd4a069f1 100644 --- a/utils/wasm-gen/src/utils.rs +++ b/utils/wasm-gen/src/utils.rs @@ -16,12 +16,18 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +#![allow(dead_code)] + use crate::wasm::PageCount as WasmPageCount; -use gear_wasm_instrument::parity_wasm::{ - builder, - elements::{ - self, BlockType, External, FuncBody, ImportCountType, Instruction, Module, Type, ValueType, +use gear_wasm_instrument::{ + parity_wasm::{ + builder, + elements::{ + self, BlockType, External, FuncBody, ImportCountType, Instruction, Module, Type, + ValueType, + }, }, + syscalls::SysCallName, }; use gsys::HashWithValue; use std::{ @@ -218,7 +224,7 @@ fn find_recursion_impl( path.pop(); } -pub fn instrument_recursion(mut module: Module) -> Module { +pub fn instrument_recursion(module: Module) -> Module { let Some(mem_size) = module .import_section() .and_then(|section| { @@ -236,6 +242,27 @@ pub fn instrument_recursion(mut module: Module) -> Module { let call_depth_ptr = mem_size - mem::size_of::() as u32; + let mut mbuilder = builder::from_module(module); + + // fn gr_leave() -> !; + let import_sig = mbuilder.push_signature(builder::signature().build_sig()); + + mbuilder.push_import( + builder::import() + .module("env") + .field(SysCallName::Leave.to_str()) + .external() + .func(import_sig) + .build(), + ); + + // back to plain module + let mut module = mbuilder.build(); + + let import_count = module.import_count(ImportCountType::Function); + let gr_leave_index @ inserted_index = import_count as u32 - 1; + let inserted_count = 1; + let Some(code_section) = module.code_section_mut() else { return module; }; @@ -243,6 +270,14 @@ pub fn instrument_recursion(mut module: Module) -> Module { for func_body in code_section.bodies_mut() { let instructions = func_body.code_mut().elements_mut(); + for instruction in instructions.iter_mut() { + if let Instruction::Call(call_index) = instruction { + if *call_index >= inserted_index { + *call_index += inserted_count + } + } + } + instructions.splice( 0..0, [ @@ -255,7 +290,7 @@ pub fn instrument_recursion(mut module: Module) -> Module { Instruction::I32Const(0), Instruction::I32Const(0), Instruction::I32Store(2, call_depth_ptr), - Instruction::Unreachable, //TODO: Instruction::Call(gr_leave_index), + Instruction::Call(gr_leave_index), Instruction::End, //call_depth += 1; Instruction::I32Const(0), @@ -282,6 +317,38 @@ pub fn instrument_recursion(mut module: Module) -> Module { ); } + let sections = module.sections_mut(); + sections.retain(|section| !matches!(section, elements::Section::Custom(_))); + + for section in sections { + match section { + elements::Section::Export(export_section) => { + for export in export_section.entries_mut() { + if let elements::Internal::Function(func_index) = export.internal_mut() { + if *func_index >= inserted_index { + *func_index += inserted_count + } + } + } + } + elements::Section::Element(elements_section) => { + for segment in elements_section.entries_mut() { + for func_index in segment.members_mut() { + if *func_index >= inserted_index { + *func_index += inserted_count + } + } + } + } + elements::Section::Start(start_idx) => { + if *start_idx >= inserted_index { + *start_idx += inserted_count + } + } + _ => {} + } + } + module }