diff --git a/utils/wasm-gen/src/generator.rs b/utils/wasm-gen/src/generator.rs
index ffa44d21a90..0f38d92641c 100644
--- a/utils/wasm-gen/src/generator.rs
+++ b/utils/wasm-gen/src/generator.rs
@@ -141,7 +141,7 @@ impl<'a, 'b> GearWasmGenerator<'a, 'b> {
Ok(if config.remove_recursions {
log::trace!("Removing recursions");
- utils::remove_recursion(module)
+ utils::instrument_recursion(module)
} else {
log::trace!("Leaving recursions");
module
diff --git a/utils/wasm-gen/src/tests.rs b/utils/wasm-gen/src/tests.rs
index b37e007401f..02c0f66ccfc 100644
--- a/utils/wasm-gen/src/tests.rs
+++ b/utils/wasm-gen/src/tests.rs
@@ -51,7 +51,6 @@ const UNSTRUCTURED_SIZE: usize = 1_000_000;
fn instrument_recursions() {
let wat1 = r#"
(module
- (func $import0 (import "env" "gr_leave"))
(memory $memory0 (import "env" "memory") 16)
(export "handle" (func $handle))
(func $handle
diff --git a/utils/wasm-gen/src/utils.rs b/utils/wasm-gen/src/utils.rs
index dae90ae572e..1afd4a069f1 100644
--- a/utils/wasm-gen/src/utils.rs
+++ b/utils/wasm-gen/src/utils.rs
@@ -16,12 +16,18 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see .
+#![allow(dead_code)]
+
use crate::wasm::PageCount as WasmPageCount;
-use gear_wasm_instrument::parity_wasm::{
- builder,
- elements::{
- self, BlockType, External, FuncBody, ImportCountType, Instruction, Module, Type, ValueType,
+use gear_wasm_instrument::{
+ parity_wasm::{
+ builder,
+ elements::{
+ self, BlockType, External, FuncBody, ImportCountType, Instruction, Module, Type,
+ ValueType,
+ },
},
+ syscalls::SysCallName,
};
use gsys::HashWithValue;
use std::{
@@ -218,7 +224,7 @@ fn find_recursion_impl(
path.pop();
}
-pub fn instrument_recursion(mut module: Module) -> Module {
+pub fn instrument_recursion(module: Module) -> Module {
let Some(mem_size) = module
.import_section()
.and_then(|section| {
@@ -236,6 +242,27 @@ pub fn instrument_recursion(mut module: Module) -> Module {
let call_depth_ptr = mem_size - mem::size_of::() as u32;
+ let mut mbuilder = builder::from_module(module);
+
+ // fn gr_leave() -> !;
+ let import_sig = mbuilder.push_signature(builder::signature().build_sig());
+
+ mbuilder.push_import(
+ builder::import()
+ .module("env")
+ .field(SysCallName::Leave.to_str())
+ .external()
+ .func(import_sig)
+ .build(),
+ );
+
+ // back to plain module
+ let mut module = mbuilder.build();
+
+ let import_count = module.import_count(ImportCountType::Function);
+ let gr_leave_index @ inserted_index = import_count as u32 - 1;
+ let inserted_count = 1;
+
let Some(code_section) = module.code_section_mut() else {
return module;
};
@@ -243,6 +270,14 @@ pub fn instrument_recursion(mut module: Module) -> Module {
for func_body in code_section.bodies_mut() {
let instructions = func_body.code_mut().elements_mut();
+ for instruction in instructions.iter_mut() {
+ if let Instruction::Call(call_index) = instruction {
+ if *call_index >= inserted_index {
+ *call_index += inserted_count
+ }
+ }
+ }
+
instructions.splice(
0..0,
[
@@ -255,7 +290,7 @@ pub fn instrument_recursion(mut module: Module) -> Module {
Instruction::I32Const(0),
Instruction::I32Const(0),
Instruction::I32Store(2, call_depth_ptr),
- Instruction::Unreachable, //TODO: Instruction::Call(gr_leave_index),
+ Instruction::Call(gr_leave_index),
Instruction::End,
//call_depth += 1;
Instruction::I32Const(0),
@@ -282,6 +317,38 @@ pub fn instrument_recursion(mut module: Module) -> Module {
);
}
+ let sections = module.sections_mut();
+ sections.retain(|section| !matches!(section, elements::Section::Custom(_)));
+
+ for section in sections {
+ match section {
+ elements::Section::Export(export_section) => {
+ for export in export_section.entries_mut() {
+ if let elements::Internal::Function(func_index) = export.internal_mut() {
+ if *func_index >= inserted_index {
+ *func_index += inserted_count
+ }
+ }
+ }
+ }
+ elements::Section::Element(elements_section) => {
+ for segment in elements_section.entries_mut() {
+ for func_index in segment.members_mut() {
+ if *func_index >= inserted_index {
+ *func_index += inserted_count
+ }
+ }
+ }
+ }
+ elements::Section::Start(start_idx) => {
+ if *start_idx >= inserted_index {
+ *start_idx += inserted_count
+ }
+ }
+ _ => {}
+ }
+ }
+
module
}