diff --git a/src/core/reader/mod.rs b/src/core/reader/mod.rs index 176c3c72..b8cab93b 100644 --- a/src/core/reader/mod.rs +++ b/src/core/reader/mod.rs @@ -212,7 +212,6 @@ pub mod span { self.len } - // TODO is this ok? pub const fn from(&self) -> usize { self.from } diff --git a/src/core/reader/types/mod.rs b/src/core/reader/types/mod.rs index 8f702031..0255c9d7 100644 --- a/src/core/reader/types/mod.rs +++ b/src/core/reader/types/mod.rs @@ -215,7 +215,7 @@ impl WasmReadable for BlockType { // Empty block type let _ = wasm.read_u8().unwrap_validated(); Ok(BlockType::Empty) - } else if let Ok(val_ty) = ValType::read(wasm) { + } else if let Ok(val_ty) = wasm.handle_transaction(|wasm| ValType::read(wasm)) { // No parameters and given valtype as the result Ok(BlockType::Returns(val_ty)) } else { @@ -232,7 +232,7 @@ impl WasmReadable for BlockType { let _ = wasm.read_u8(); BlockType::Empty - } else if let Ok(val_ty) = ValType::read(wasm) { + } else if let Ok(val_ty) = wasm.handle_transaction(|wasm| ValType::read(wasm)) { // No parameters and given valtype as the result BlockType::Returns(val_ty) } else { diff --git a/src/execution/interpreter_loop.rs b/src/execution/interpreter_loop.rs index 4c10c6ca..ae74c031 100644 --- a/src/execution/interpreter_loop.rs +++ b/src/execution/interpreter_loop.rs @@ -16,7 +16,7 @@ use alloc::vec::Vec; use crate::{ assert_validated::UnwrapValidatedExt, core::{ - indices::{DataIdx, FuncIdx, GlobalIdx, LocalIdx}, + indices::{DataIdx, FuncIdx, GlobalIdx, LabelIdx, LocalIdx}, reader::{ types::{memarg::MemArg, BlockType, FuncType}, WasmReadable, WasmReader, @@ -127,6 +127,24 @@ pub(super) fn run( stp += 1; } } + BR_TABLE => { + let label_vec = wasm + .read_vec(|wasm| wasm.read_var_u32().map(|v| v as LabelIdx)) + .unwrap_validated(); + wasm.read_var_u32().unwrap_validated(); + + // TODO is this correct? + let case_val_i32: i32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); + let case_val = case_val_i32 as usize; + + if case_val >= label_vec.len() { + stp += label_vec.len(); + } else { + stp += case_val; + } + + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } BR => { //skip n of BR n wasm.read_var_u32().unwrap_validated(); diff --git a/src/validation/code.rs b/src/validation/code.rs index 41fedb3c..575700ac 100644 --- a/src/validation/code.rs +++ b/src/validation/code.rs @@ -267,6 +267,26 @@ fn read_instructions( wasm, label_idx, stack, sidetable, )?; } + BR_TABLE => { + let label_vec = wasm.read_vec(|wasm| wasm.read_var_u32().map(|v| v as LabelIdx))?; + let max_label_idx = wasm.read_var_u32()? as LabelIdx; + stack.assert_pop_val_type(ValType::NumType(NumType::I32))?; + + for label_idx in label_vec { + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; + } + + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, + max_label_idx, + stack, + sidetable, + )?; + + stack.make_unspecified()?; + } END => { let (label_info, _) = stack.assert_pop_ctrl()?; let stp_here = sidetable.len(); diff --git a/tests/structured_control_flow/block.rs b/tests/structured_control_flow/block.rs index 24257cc1..b5d754a2 100644 --- a/tests/structured_control_flow/block.rs +++ b/tests/structured_control_flow/block.rs @@ -261,11 +261,11 @@ fn branch_if() { let validation_info = validate(&wasm_bytes).expect("validation failed"); let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); - let abs_fn = instance.get_function_by_index(0, 0).unwrap(); + let switch_case_fn = instance.get_function_by_index(0, 0).unwrap(); - assert_eq!(6, instance.invoke(&abs_fn, 6).unwrap()); - assert_eq!(123, instance.invoke(&abs_fn, -123).unwrap()); - assert_eq!(0, instance.invoke(&abs_fn, 0).unwrap()); + assert_eq!(6, instance.invoke(&switch_case_fn, 6).unwrap()); + assert_eq!(123, instance.invoke(&switch_case_fn, -123).unwrap()); + assert_eq!(0, instance.invoke(&switch_case_fn, 0).unwrap()); } #[test_log::test] @@ -330,3 +330,52 @@ fn recursive_fibonacci() { .collect::>(); assert_eq!(&first_ten, &[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]); } + +#[test_log::test] +fn switch_case() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func $switch_case (param $value i32) (result i32) + (block $default + (block $case4 + (block $case3 + (block $case2 + (block $case1 + local.get $value + (br_table $case1 $case2 $case3 $case4 $default) + ) + i32.const 1 + return + ) + i32.const 3 + return + ) + i32.const 5 + return + ) + i32.const 7 + return + ) + i32.const 9 + return + ) + (export "switch_case" (func $switch_case)) + )"#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let switch_case_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(9, instance.invoke(&switch_case_fn, -5).unwrap()); + assert_eq!(9, instance.invoke(&switch_case_fn, -1).unwrap()); + assert_eq!(1, instance.invoke(&switch_case_fn, 0).unwrap()); + assert_eq!(3, instance.invoke(&switch_case_fn, 1).unwrap()); + assert_eq!(5, instance.invoke(&switch_case_fn, 2).unwrap()); + assert_eq!(7, instance.invoke(&switch_case_fn, 3).unwrap()); + assert_eq!(9, instance.invoke(&switch_case_fn, 4).unwrap()); + assert_eq!(9, instance.invoke(&switch_case_fn, 7).unwrap()); +}