diff --git a/src/core/error.rs b/src/core/error.rs index 256ed4e3..8d7c16af 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -2,7 +2,7 @@ use alloc::format; use alloc::string::{String, ToString}; use crate::core::indices::GlobalIdx; -use crate::validation_stack::{LabelKind, ValidationStackEntry}; +use crate::validation_stack::ValidationStackEntry; use crate::RefType; use core::fmt::{Display, Formatter}; use core::str::Utf8Error; @@ -47,6 +47,7 @@ pub enum Error { InvalidNumType, InvalidVecType, InvalidFuncType, + InvalidFuncTypeIdx, InvalidRefType, InvalidValType, InvalidExportDesc(u8), @@ -67,13 +68,15 @@ pub enum Error { InvalidGlobalIdx(GlobalIdx), GlobalIsConst, RuntimeError(RuntimeError), - FoundLabel(LabelKind), - FoundUnspecifiedValTypes, MemoryIsNotDefined(MemIdx), // mem.align, wanted alignment ErroneousAlignment(u32, u32), NoDataSegments, DataSegmentNotFound(DataIdx), + InvalidLabelIdx(usize), + ValidationCtrlStackEmpty, + ElseWithoutMatchingIf, + IfWithoutMatchingElse, UnknownTable, TableIsNotDefined(TableIdx), ElementIsNotDefined(ElemIdx), @@ -115,6 +118,9 @@ impl Display for Error { Error::InvalidFuncType => { f.write_str("An invalid byte was read where a functype was expected") } + Error::InvalidFuncTypeIdx => { + f.write_str("An invalid index to the fuctypes table was read") + } Error::InvalidRefType => { f.write_str("An invalid byte was read where a reftype was expected") } @@ -160,10 +166,6 @@ impl Display for Error { )), Error::GlobalIsConst => f.write_str("A const global cannot be written to"), Error::RuntimeError(err) => err.fmt(f), - Error::FoundLabel(lk) => f.write_fmt(format_args!( - "Expecting a ValType, a Label was found: {lk:?}" - )), - Error::FoundUnspecifiedValTypes => f.write_str("Found UnspecifiedValTypes"), Error::ExpectedAnOperand => f.write_str("Expected a ValType"), // Error => f.write_str("Expected an operand (ValType) on the stack") Error::MemoryIsNotDefined(memidx) => f.write_fmt(format_args!( "C.mems[{}] is NOT defined when it should be", @@ -179,6 +181,18 @@ impl Display for Error { Error::DataSegmentNotFound(data_idx) => { f.write_fmt(format_args!("Data Segment {} not found", data_idx)) } + Error::InvalidLabelIdx(label_idx) => { + f.write_fmt(format_args!("invalid label index {}", label_idx)) + } + Error::ValidationCtrlStackEmpty => { + f.write_str("cannot retrieve last ctrl block, validation ctrl stack is empty") + } + Error::ElseWithoutMatchingIf => { + f.write_str("read 'else' without a previous matching 'if' instruction") + } + Error::IfWithoutMatchingElse => { + f.write_str("read 'end' without matching 'else' instruction to 'if' instruction") + } Error::UnknownTable => f.write_str("Unknown Table"), Error::TableIsNotDefined(table_idx) => f.write_fmt(format_args!( "C.tables[{}] is NOT defined when it should be", diff --git a/src/core/mod.rs b/src/core/mod.rs index 92d3b701..90631750 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -2,3 +2,4 @@ pub mod error; pub mod indices; pub mod reader; +pub mod sidetable; diff --git a/src/core/reader/mod.rs b/src/core/reader/mod.rs index f263efbd..b8cab93b 100644 --- a/src/core/reader/mod.rs +++ b/src/core/reader/mod.rs @@ -211,6 +211,10 @@ pub mod span { pub const fn len(&self) -> usize { self.len } + + pub const fn from(&self) -> usize { + self.from + } } impl<'a> Index for WasmReader<'a> { diff --git a/src/core/reader/types/element.rs b/src/core/reader/types/element.rs index 1e98d7e8..9c3216ef 100644 --- a/src/core/reader/types/element.rs +++ b/src/core/reader/types/element.rs @@ -132,7 +132,7 @@ impl ElemType { use crate::validation_stack::ValidationStackEntry::*; - if let Some(val) = valid_stack.peek_stack() { + if let Some(val) = valid_stack.peek_const_validation_stack() { if let Val(val) = val { match val { crate::ValType::RefType(_) => {} diff --git a/src/core/reader/types/mod.rs b/src/core/reader/types/mod.rs index 06727588..d2720f12 100644 --- a/src/core/reader/types/mod.rs +++ b/src/core/reader/types/mod.rs @@ -234,6 +234,85 @@ impl WasmReadable for FuncType { } } +/// +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BlockType { + Empty, + Returns(ValType), + Type(u32), +} + +impl WasmReadable for BlockType { + fn read(wasm: &mut WasmReader) -> Result { + if wasm.peek_u8()? as i8 == 0x40 { + // Empty block type + let _ = wasm.read_u8().unwrap_validated(); + Ok(BlockType::Empty) + } else if let Ok(val_ty) = wasm.handle_transaction(|wasm| ValType::read(wasm)) { + // No parameters and given valtype as the result + Ok(BlockType::Returns(val_ty)) + } else { + // An index to a function type + wasm.read_var_i33() + .and_then(|idx| idx.try_into().map_err(|_| Error::InvalidFuncTypeIdx)) + .map(BlockType::Type) + } + } + + fn read_unvalidated(wasm: &mut WasmReader) -> Self { + if wasm.peek_u8().unwrap_validated() as i8 == 0x40 { + // Empty block type + let _ = wasm.read_u8(); + + BlockType::Empty + } else if let Ok(val_ty) = wasm.handle_transaction(|wasm| ValType::read(wasm)) { + // No parameters and given valtype as the result + BlockType::Returns(val_ty) + } else { + // An index to a function type + BlockType::Type( + wasm.read_var_i33() + .unwrap_validated() + .try_into() + .unwrap_validated(), + ) + } + } +} + +impl BlockType { + pub fn as_func_type(&self, func_types: &[FuncType]) -> Result { + match self { + BlockType::Empty => Ok(FuncType { + params: ResultType { + valtypes: Vec::new(), + }, + returns: ResultType { + valtypes: Vec::new(), + }, + }), + BlockType::Returns(val_type) => Ok(FuncType { + params: ResultType { + valtypes: Vec::new(), + }, + returns: ResultType { + valtypes: [*val_type].into(), + }, + }), + BlockType::Type(type_idx) => { + let type_idx: usize = (*type_idx) + .try_into() + .map_err(|_| Error::InvalidFuncTypeIdx)?; + + func_types + .get(type_idx) + .cloned() + .ok_or(Error::InvalidFuncTypeIdx) + } + } + } +} + #[derive(Copy, Clone, PartialEq, Eq)] pub struct Limits { pub min: u32, diff --git a/src/core/reader/types/opcode.rs b/src/core/reader/types/opcode.rs index aa0e1a26..6a82bde2 100644 --- a/src/core/reader/types/opcode.rs +++ b/src/core/reader/types/opcode.rs @@ -1,6 +1,17 @@ +//! All opcodes, in alphanumerical order by their numeric (hex-)value pub const UNREACHABLE: u8 = 0x00; pub const NOP: u8 = 0x01; +pub const BLOCK: u8 = 0x02; +pub const LOOP: u8 = 0x03; +pub const IF: u8 = 0x04; +#[allow(unused)] // TODO remove this once sidetable implementation lands +pub const ELSE: u8 = 0x05; pub const END: u8 = 0x0B; +pub const BR: u8 = 0x0C; +#[allow(unused)] // TODO remove this once sidetable implementation lands +pub const BR_IF: u8 = 0x0D; +#[allow(unused)] // TODO remove this once sidetable implementation lands +pub const BR_TABLE: u8 = 0x0E; pub const RETURN: u8 = 0x0F; pub const CALL: u8 = 0x10; pub const DROP: u8 = 0x1A; diff --git a/src/core/reader/types/values.rs b/src/core/reader/types/values.rs index 28780bf5..6793739c 100644 --- a/src/core/reader/types/values.rs +++ b/src/core/reader/types/values.rs @@ -67,6 +67,27 @@ impl WasmReader<'_> { Ok(result) } + pub fn read_var_i33(&mut self) -> Result { + let mut result: i64 = 0; + let mut shift: u64 = 0; + + let mut byte: i64; + loop { + byte = self.read_u8()? as i64; + result |= (byte & 0b0111_1111) << shift; + shift += 7; + if (byte & 0b1000_0000) == 0 { + break; + } + } + + if shift < 33 && (byte & 0x40 != 0) { + result |= !0 << shift; + } + + Ok(result) + } + pub fn read_var_f32(&mut self) -> Result { if self.full_wasm_binary.len() - self.pc < 4 { return Err(Error::Eof); diff --git a/src/core/sidetable.rs b/src/core/sidetable.rs new file mode 100644 index 00000000..0a23349e --- /dev/null +++ b/src/core/sidetable.rs @@ -0,0 +1,65 @@ +//! This module contains a data structure to allow in-place interpretation +//! +//! Control-flow in WASM is denoted in labels. To avoid linear search through the WASM binary or +//! stack for the respective label of a branch, a sidetable is generated during validation, which +//! stores the offset on the current instruction pointer for the branch. A sidetable entry hence +//! allows to translate the implicit control flow information ("jump to the next `else`") to +//! explicit modifications of the instruction pointer (`instruction_pointer += 13`). +//! +//! Branches in WASM can only go outwards, they either `break` out of a block or `continue` to the +//! head of a loob block. Put differently, a label can only be referenced from within its +//! associated structured control instruction. +//! +//! Noteworthy, branching can also have side-effects on the operand stack: +//! +//! - Taking a branch unwinds the operand stack, down to where the targeted structured control flow +//! instruction was entered. [`SidetableEntry::popcnt`] holds information on how many values to +//! pop from the operand stack when a branch is taken. +//! - When a branch is taken, it may consume arguments from the operand stack. These are pushed +//! back on the operand stack after unwinding. This behavior can be emulated by copying the +//! uppermost [`SidetableEntry::valcnt`] operands on the operand stack before taking a branch +//! into a structured control instruction. +//! +//! # Reference +//! +//! - Core / Syntax / Instructions / Control Instructions, WASM Spec, +//! +//! - "A fast in-place interpreter for WebAssembly", Ben L. Titzer, +//! + +use alloc::vec::Vec; + +// A sidetable + +pub type Sidetable = Vec; + +/// Entry to translate the current branches implicit target into an explicit offset to the instruction pointer, as well as the side table pointer +/// +/// Each of the following constructs requires a [`SidetableEntry`]: +/// +/// - br +/// - br_if +/// - br_table +/// - else +// TODO hide implementation +// TODO Remove Clone trait from sidetables +#[derive(Clone)] +pub struct SidetableEntry { + /// Δpc: the amount to adjust the instruction pointer by if the branch is taken + pub delta_pc: isize, + + /// Δstp: the amount to adjust the side-table index by if the branch is taken + pub delta_stp: isize, + + /// valcnt: the number of values that will be copied if the branch is taken + /// + /// Branches may additionally consume operands themselves, which they push back on the operand + /// stack after unwinding. + pub valcnt: usize, + + /// popcnt: the number of values that will be popped if the branch is taken + /// + /// Taking a branch unwinds the operand stack down to the height where the targeted structured + /// control instruction was entered. + pub popcnt: usize, +} diff --git a/src/execution/interpreter_loop.rs b/src/execution/interpreter_loop.rs index ef376387..5ccfc00a 100644 --- a/src/execution/interpreter_loop.rs +++ b/src/execution/interpreter_loop.rs @@ -16,11 +16,12 @@ use alloc::vec::Vec; use crate::{ assert_validated::UnwrapValidatedExt, core::{ - indices::{DataIdx, FuncIdx, GlobalIdx, LocalIdx, TableIdx, TypeIdx}, + indices::{DataIdx, FuncIdx, GlobalIdx, LabelIdx, LocalIdx, TableIdx, TypeIdx}, reader::{ - types::{memarg::MemArg, FuncType}, + types::{memarg::MemArg, BlockType, FuncType}, WasmReadable, WasmReader, }, + sidetable::Sidetable, }, locals::Locals, store::{DataInst, Store}, @@ -48,6 +49,11 @@ pub(super) fn run( // Start reading the function's instructions let mut wasm = WasmReader::new(wasm_bytecode); + // the sidetable and stp for this function, stp will reset to 0 every call + // since function instances have their own sidetable. + let mut current_sidetable: &Sidetable = &func_inst.sidetable; + let mut stp = 0; + // unwrap is sound, because the validation assures that the function points to valid subslice of the WASM binary wasm.move_start_to(func_inst.code_expr).unwrap(); @@ -64,7 +70,20 @@ pub(super) fn run( trace!("Instruction: NOP"); } END => { - let maybe_return_address = stack.pop_stackframe(); + // if this is not the very last instruction in the function + // just skip because it is a delimiter of a ctrl block + + // TODO there is definitely a better to write this + let current_func_span = store + .funcs + .get(stack.current_stackframe().func_idx) + .unwrap_validated() + .code_expr; + if wasm.pc != current_func_span.from() + current_func_span.len() { + continue; + } + + let (maybe_return_address, maybe_return_stp) = stack.pop_stackframe(); // We finished this entire invocation if there is no stackframe left. If there are // one or more stack frames, we need to continue from where the callee was called @@ -75,30 +94,68 @@ pub(super) fn run( trace!("end of function reached, returning to previous stack frame"); wasm.pc = maybe_return_address; + stp = maybe_return_stp; + + current_sidetable = &store + .funcs + .get(stack.current_stackframe().func_idx) + .unwrap_validated() + .sidetable; } - RETURN => { - trace!("returning from function"); + IF => { + wasm.read_var_u32().unwrap_validated(); - let func_to_call_idx = stack.current_stackframe().func_idx; + let test_val: i32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - let func_to_call_inst = store.funcs.get(func_to_call_idx).unwrap_validated(); - let func_to_call_ty = types.get(func_to_call_inst.ty).unwrap_validated(); + if test_val != 0 { + stp += 1; + } else { + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } + } + ELSE => { + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } + BR_IF => { + wasm.read_var_u32().unwrap_validated(); - let ret_vals = stack - .pop_tail_iter(func_to_call_ty.returns.valtypes.len()) - .collect::>(); - stack.clear_callframe_values(); + let test_val: i32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); - for val in ret_vals { - stack.push_value(val); + if test_val != 0 { + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } else { + stp += 1; } + } + BR_TABLE => { + let label_vec = wasm + .read_vec(|wasm| wasm.read_var_u32().map(|v| v as LabelIdx)) + .unwrap_validated(); + wasm.read_var_u32().unwrap_validated(); - if stack.callframe_count() == 1 { - break; + // TODO is this correct? + let case_val_i32: i32 = stack.pop_value(ValType::NumType(NumType::I32)).into(); + let case_val = case_val_i32 as usize; + + if case_val >= label_vec.len() { + stp += label_vec.len(); + } else { + stp += case_val; } - trace!("end of function reached, returning to previous stack frame"); - wasm.pc = stack.pop_stackframe(); + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } + BR => { + //skip n of BR n + wasm.read_var_u32().unwrap_validated(); + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); + } + BLOCK | LOOP => { + BlockType::read_unvalidated(&mut wasm); + } + RETURN => { + //same as BR, except no need to skip n of BR n + do_sidetable_control_transfer(&mut wasm, stack, &mut stp, current_sidetable); } CALL => { let func_to_call_idx = wasm.read_var_u32().unwrap_validated() as FuncIdx; @@ -111,10 +168,12 @@ pub(super) fn run( trace!("Instruction: call [{func_to_call_idx:?}]"); let locals = Locals::new(params, remaining_locals); - stack.push_stackframe(func_to_call_idx, func_to_call_ty, locals, wasm.pc); + stack.push_stackframe(func_to_call_idx, func_to_call_ty, locals, wasm.pc, stp); wasm.move_start_to(func_to_call_inst.code_expr) .unwrap_validated(); + stp = 0; + current_sidetable = &func_to_call_inst.sidetable; } CALL_INDIRECT => { let type_idx = wasm.read_var_u32().unwrap_validated() as TypeIdx; @@ -159,10 +218,12 @@ pub(super) fn run( trace!("Instruction: call_indirect [{func_addr:?}]"); let locals = Locals::new(params, remaining_locals); - stack.push_stackframe(func_addr.unwrap_validated(), func_ty, locals, wasm.pc); + stack.push_stackframe(func_addr.unwrap_validated(), func_ty, locals, wasm.pc, stp); wasm.move_start_to(func_to_call_inst.code_expr) .unwrap_validated(); + stp = 0; + current_sidetable = &func_to_call_inst.sidetable; } DROP => { stack.drop_value(); @@ -2468,3 +2529,26 @@ pub(super) fn run( } Ok(()) } + +//helper function for avoiding code duplication at intraprocedural jumps +fn do_sidetable_control_transfer( + wasm: &mut WasmReader, + stack: &mut Stack, + current_stp: &mut usize, + current_sidetable: &Sidetable, +) { + let sidetable_entry = ¤t_sidetable[*current_stp]; + + // TODO fix this corner cutting implementation + let jump_vals = stack + .pop_tail_iter(sidetable_entry.valcnt) + .collect::>(); + stack.pop_n_values(sidetable_entry.popcnt); + + for val in jump_vals { + stack.push_value(val); + } + + *current_stp = (*current_stp as isize + sidetable_entry.delta_stp) as usize; + wasm.pc = (wasm.pc as isize + sidetable_entry.delta_pc) as usize; +} diff --git a/src/execution/mod.rs b/src/execution/mod.rs index b77b709f..ba208524 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -179,7 +179,7 @@ where // setting `usize::MAX` as return address for the outermost function ensures that we // observably fail upon errornoeusly continuing execution after that function returns. - stack.push_stackframe(func_idx, func_ty, locals, usize::MAX); + stack.push_stackframe(func_idx, func_ty, locals, usize::MAX, usize::MAX); // Run the interpreter run( @@ -232,7 +232,7 @@ where // Prepare a new stack with the locals for the entry function let mut stack = Stack::new(); let locals = Locals::new(params.into_iter(), func_inst.locals.iter().cloned()); - stack.push_stackframe(func_idx, func_ty, locals, 0); + stack.push_stackframe(func_idx, func_ty, locals, 0, 0); // Run the interpreter run( @@ -325,7 +325,7 @@ where functions .zip(func_blocks) - .map(|(ty, func)| { + .map(|(ty, (func, sidetable))| { wasm_reader .move_start_to(*func) .expect("function index to be in the bounds of the WASM binary"); @@ -342,6 +342,8 @@ where ty: *ty, locals, code_expr, + // TODO fix this ugly clone + sidetable: sidetable.clone(), } }) .collect() diff --git a/src/execution/store.rs b/src/execution/store.rs index 0a90afce..c3ad3007 100644 --- a/src/execution/store.rs +++ b/src/execution/store.rs @@ -6,6 +6,7 @@ use crate::core::indices::TypeIdx; use crate::core::reader::span::Span; use crate::core::reader::types::global::Global; use crate::core::reader::types::{MemType, TableType, ValType}; +use crate::core::sidetable::Sidetable; use crate::execution::value::{Ref, Value}; use crate::RefType; @@ -44,6 +45,7 @@ pub struct FuncInst { pub ty: TypeIdx, pub locals: Vec, pub code_expr: Span, + pub sidetable: Sidetable, } #[derive(Debug)] diff --git a/src/execution/value_stack.rs b/src/execution/value_stack.rs index 7bedb230..5fe5a8c1 100644 --- a/src/execution/value_stack.rs +++ b/src/execution/value_stack.rs @@ -154,12 +154,13 @@ impl Stack { self.frames.last_mut().unwrap_validated() } - /// Pop a [`CallFrame`] from the call stack, returning the return address - pub fn pop_stackframe(&mut self) -> usize { + /// Pop a [`CallFrame`] from the call stack, returning the return address and the return stp + pub fn pop_stackframe(&mut self) -> (usize, usize) { let CallFrame { return_addr, value_stack_base_idx, return_value_count, + return_stp, .. } = self.frames.pop().unwrap_validated(); @@ -172,7 +173,7 @@ impl Stack { "after a function call finished, the stack must have exactly as many values as it had before calling the function plus the number of function return values" ); - return_addr + (return_addr, return_stp) } /// Push a stackframe to the call stack @@ -184,6 +185,7 @@ impl Stack { func_ty: &FuncType, locals: Locals, return_addr: usize, + return_stp: usize, ) { self.frames.push(CallFrame { func_idx, @@ -191,6 +193,7 @@ impl Stack { return_addr, value_stack_base_idx: self.values.len(), return_value_count: func_ty.returns.valtypes.len(), + return_stp, }) } @@ -210,10 +213,16 @@ impl Stack { } /// Clear all of the values pushed to the value stack by the current stack frame + #[allow(unused)] // TODO remove this once sidetable implementation lands pub fn clear_callframe_values(&mut self) { self.values .truncate(self.current_stackframe().value_stack_base_idx); } + + // TODO change this interface + pub fn pop_n_values(&mut self, n: usize) { + self.values.truncate(self.values.len() - n); + } } /// The [WASM spec](https://webassembly.github.io/spec/core/exec/runtime.html#stack) calls this `Activations`, however it refers to the call frames of functions. @@ -232,4 +241,7 @@ pub(crate) struct CallFrame { /// Number of return values to retain on [`Stack::values`] when unwinding/popping a [`CallFrame`] pub return_value_count: usize, + + // Value that the stp has to be set to when this function returns + pub return_stp: usize, } diff --git a/src/validation/code.rs b/src/validation/code.rs index 7552429c..0f580636 100644 --- a/src/validation/code.rs +++ b/src/validation/code.rs @@ -4,16 +4,17 @@ use alloc::vec::Vec; use core::iter; use crate::core::indices::{ - DataIdx, ElemIdx, FuncIdx, GlobalIdx, LocalIdx, MemIdx, TableIdx, TypeIdx, + DataIdx, ElemIdx, FuncIdx, GlobalIdx, LabelIdx, LocalIdx, MemIdx, TableIdx, TypeIdx, }; use crate::core::reader::section_header::{SectionHeader, SectionTy}; use crate::core::reader::span::Span; use crate::core::reader::types::element::ElemType; use crate::core::reader::types::global::Global; use crate::core::reader::types::memarg::MemArg; -use crate::core::reader::types::{FuncType, MemType, NumType, TableType, ValType}; +use crate::core::reader::types::{BlockType, FuncType, MemType, NumType, TableType, ValType}; use crate::core::reader::{WasmReadable, WasmReader}; -use crate::validation_stack::ValidationStack; +use crate::core::sidetable::{Sidetable, SidetableEntry}; +use crate::validation_stack::{LabelInfo, ValidationStack}; use crate::{Error, RefType, Result}; #[allow(clippy::too_many_arguments)] @@ -28,10 +29,11 @@ pub fn validate_code_section( tables: &[TableType], elements: &[ElemType], referenced_functions: &BTreeSet, -) -> Result> { +) -> Result> { assert_eq!(section_header.ty, SectionTy::Code); - let code_block_spans = wasm.read_vec_enumerated(|wasm, idx| { + // TODO replace with single sidetable per module + let code_block_spans_sidetables = wasm.read_vec_enumerated(|wasm, idx| { let ty_idx = type_idx_of_fn[idx]; let func_ty = fn_types[ty_idx].clone(); @@ -45,12 +47,13 @@ pub fn validate_code_section( params.chain(declared_locals).collect::>() }; - let mut stack = ValidationStack::new(); + let mut stack = ValidationStack::new_for_func(func_ty); + let mut sidetable: Sidetable = Sidetable::default(); read_instructions( - idx, wasm, &mut stack, + &mut sidetable, &locals, globals, fn_types, @@ -69,15 +72,15 @@ pub fn validate_code_section( ) } - Ok(func_block) + Ok((func_block, sidetable)) })?; trace!( "Read code section. Found {} code blocks", - code_block_spans.len() + code_block_spans_sidetables.len() ); - Ok(code_block_spans) + Ok(code_block_spans_sidetables) } pub fn read_declared_locals(wasm: &mut WasmReader) -> Result> { @@ -97,11 +100,79 @@ pub fn read_declared_locals(wasm: &mut WasmReader) -> Result> { Ok(locals) } +//helper function to avoid code duplication in jump validations +//the entries, except for the loop label, need to be correctly backpatched later +//the temporary values of fields (delta_pc, delta_stp) of the entries are the (ip, stp) of the relevant label +//the label is also updated with the additional information of the index of this sidetable +//entry itself so that the entry can be backpatched when the end instruction of the label +//is hit. +fn generate_unbackpatched_sidetable_entry( + wasm: &WasmReader, + sidetable: &mut Sidetable, + valcnt: usize, + popcnt: usize, + label_info: &mut LabelInfo, +) { + let stp_here = sidetable.len(); + + sidetable.push(SidetableEntry { + delta_pc: wasm.pc as isize, + delta_stp: stp_here as isize, + popcnt, + valcnt, + }); + + match label_info { + LabelInfo::Block { stps_to_backpatch } => stps_to_backpatch.push(stp_here), + LabelInfo::Loop { ip, stp } => { + //we already know where to jump to for loops + sidetable[stp_here].delta_pc = *ip as isize - wasm.pc as isize; + sidetable[stp_here].delta_stp = *stp as isize - stp_here as isize; + } + LabelInfo::If { + stps_to_backpatch, .. + } => stps_to_backpatch.push(stp_here), + LabelInfo::Func { stps_to_backpatch } => stps_to_backpatch.push(stp_here), + LabelInfo::Untyped => { + unreachable!("this label is for untyped wasm sequences") + } + } +} + +//helper function to avoid code duplication for common stuff in br, br_if, return +fn validate_intrablock_jump_and_generate_sidetable_entry( + wasm: &WasmReader, + label_idx: usize, + stack: &mut ValidationStack, + sidetable: &mut Sidetable, +) -> Result<()> { + let ctrl_stack_len = stack.ctrl_stack.len(); + + stack.assert_val_types_of_label_jump_types_on_top(label_idx)?; + + let targeted_ctrl_block_entry = stack + .ctrl_stack + .get(ctrl_stack_len - label_idx - 1) + .ok_or(Error::InvalidLabelIdx(label_idx))?; + + let valcnt = targeted_ctrl_block_entry.label_types().len(); + let popcnt = stack.len() - targeted_ctrl_block_entry.height - valcnt; + + let label_info = &mut stack + .ctrl_stack + .get_mut(ctrl_stack_len - label_idx - 1) + .unwrap() + .label_info; + + generate_unbackpatched_sidetable_entry(wasm, sidetable, valcnt, popcnt, label_info); + Ok(()) +} + #[allow(clippy::too_many_arguments)] fn read_instructions( - this_function_idx: usize, wasm: &mut WasmReader, stack: &mut ValidationStack, + sidetable: &mut Sidetable, locals: &[ValType], globals: &[Global], fn_types: &[FuncType], @@ -112,7 +183,6 @@ fn read_instructions( elements: &[ElemType], referenced_functions: &BTreeSet, ) -> Result<()> { - // TODO we must terminate only if both we saw the final `end` and when we consumed all of the code span loop { let Ok(first_instr_byte) = wasm.read_u8() else { // TODO only do this if EOF @@ -122,66 +192,166 @@ fn read_instructions( use crate::core::reader::types::opcode::*; match first_instr_byte { - // nop + // nop: [] -> [] NOP => {} - // end - END => { - // TODO check if there are labels on the stack. - // If there are none (i.e. this is the implicit end of the function and not a jump to the end of a function), the stack must only contain the valid return values, no other junk. - // - // Else, anything may remain on the stack, as long as the top of the stack matche the current blocks return value. - - if stack.has_remaining_label() { - // This is the END of a block. + // block: [] -> [t*2] + BLOCK => { + let block_ty = BlockType::read(wasm)?.as_func_type(fn_types)?; + let label_info = LabelInfo::Block { + stps_to_backpatch: Vec::new(), + }; + stack.assert_push_ctrl(label_info, block_ty)?; + } + LOOP => { + let block_ty = BlockType::read(wasm)?.as_func_type(fn_types)?; + let label_info = LabelInfo::Loop { + ip: wasm.pc, + stp: sidetable.len(), + }; + stack.assert_push_ctrl(label_info, block_ty)?; + } + IF => { + let block_ty = BlockType::read(wasm)?.as_func_type(fn_types)?; - // We check the valtypes on top of the stack + stack.assert_pop_val_type(ValType::NumType(NumType::I32))?; - // TODO remove the ugly hack for the todo!(..)! - #[allow(clippy::diverging_sub_expression)] - { - let _block_return_ty: &[ValType] = - todo!("get return types for current block"); + let stp_here = sidetable.len(); + sidetable.push(SidetableEntry { + delta_pc: wasm.pc as isize, + delta_stp: stp_here as isize, + popcnt: 0, + valcnt: block_ty.params.valtypes.len(), + }); + + let label_info = LabelInfo::If { + stp: stp_here, + stps_to_backpatch: Vec::new(), + }; + stack.assert_push_ctrl(label_info, block_ty)?; + } + ELSE => { + let (mut label_info, block_ty) = stack.assert_pop_ctrl()?; + if let LabelInfo::If { + stp, + stps_to_backpatch, + } = &mut label_info + { + if *stp == usize::MAX { + //this If was previously matched with an else already, it is already backpatched! + return Err(Error::IfWithoutMatchingElse); + } + let stp_here = sidetable.len(); + sidetable.push(SidetableEntry { + delta_pc: wasm.pc as isize, + delta_stp: stp_here as isize, + popcnt: 0, + valcnt: block_ty.returns.valtypes.len(), + }); + stps_to_backpatch.push(stp_here); + + sidetable[*stp].delta_pc = wasm.pc as isize - sidetable[*stp].delta_pc; + sidetable[*stp].delta_stp = + sidetable.len() as isize - sidetable[*stp].delta_stp; + + *stp = usize::MAX; // mark this If as backpatched + + for valtype in block_ty.returns.valtypes.iter().rev() { + stack.assert_pop_val_type(*valtype)?; } - // stack.assert_val_types_on_top(block_return_ty)?; - // Clear the stack until the next label - // stack.clear_until_next_label(); + for valtype in block_ty.params.valtypes.iter() { + stack.push_valtype(*valtype); + } - // And push the blocks return types onto the stack again - // for valtype in block_return_ty { - // stack.push_valtype(*valtype); - // } + stack.assert_push_ctrl(label_info, block_ty)?; } else { - // This is the last end of a function - - // The stack must only contain the function's return valtypes - let this_func_ty = &fn_types[type_idx_of_fn[this_function_idx]]; - stack.assert_val_types(&this_func_ty.returns.valtypes)?; - return Ok(()); + return Err(Error::ElseWithoutMatchingIf); } } - RETURN => { - let this_func_ty = &fn_types[type_idx_of_fn[this_function_idx]]; + BR => { + let label_idx = wasm.read_var_u32()? as LabelIdx; + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; + stack.make_unspecified()?; + } + BR_IF => { + let label_idx = wasm.read_var_u32()? as LabelIdx; + stack.assert_pop_val_type(ValType::NumType(NumType::I32))?; + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; + } + BR_TABLE => { + let label_vec = wasm.read_vec(|wasm| wasm.read_var_u32().map(|v| v as LabelIdx))?; + let max_label_idx = wasm.read_var_u32()? as LabelIdx; + stack.assert_pop_val_type(ValType::NumType(NumType::I32))?; - stack - .assert_val_types_on_top(&this_func_ty.returns.valtypes) - .map_err(|_| Error::EndInvalidValueStack)?; + for label_idx in label_vec { + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; + } - stack.make_unspecified(); + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, + max_label_idx, + stack, + sidetable, + )?; - // TODO(george-cosma): a `return Ok(());` should probably be introduced here, but since we don't have - // controls flows implemented, the only way to test `return` is to place it at the end of function. - // However, an `end` is introduced after it, which is invalid. Compilation for this test case should - // probably fail. + stack.make_unspecified()?; + } + END => { + let (label_info, _) = stack.assert_pop_ctrl()?; + let stp_here = sidetable.len(); + + match label_info { + LabelInfo::Block { stps_to_backpatch } => { + stps_to_backpatch.iter().for_each(|i| { + sidetable[*i].delta_pc = (wasm.pc as isize) - sidetable[*i].delta_pc; + sidetable[*i].delta_stp = (stp_here as isize) - sidetable[*i].delta_stp; + }); + } + LabelInfo::If { + stp, + stps_to_backpatch, + } => { + if stp != usize::MAX { + //This If is still not backpatched, meaning it does not have a corresponding + //ELSE. Therefore if its condition fails, it jumps after END. + sidetable[stp].delta_pc = (wasm.pc as isize) - sidetable[stp].delta_pc; + sidetable[stp].delta_stp = + (stp_here as isize) - sidetable[stp].delta_stp; + } + stps_to_backpatch.iter().for_each(|i| { + sidetable[*i].delta_pc = (wasm.pc as isize) - sidetable[*i].delta_pc; + sidetable[*i].delta_stp = (stp_here as isize) - sidetable[*i].delta_stp; + }); + } + LabelInfo::Loop { .. } => (), + LabelInfo::Func { stps_to_backpatch } => { + // same as blocks, except jump just before the end instr, not after it + // the last end instruction will handle the return to callee during execution + stps_to_backpatch.iter().for_each(|i| { + sidetable[*i].delta_pc = + (wasm.pc as isize) - sidetable[*i].delta_pc - 1; // minus 1 is important! + sidetable[*i].delta_stp = (stp_here as isize) - sidetable[*i].delta_stp; + }); + } + LabelInfo::Untyped => unreachable!("this label is for untyped wasm sequences"), + } - // TODO(wucke13) I believe we must not drain the validation stack here; only if we - // know this return is actually taken during execution we may drain the stack. This - // could however be a conditional return (return in an `if`), and the other side - // past the `else` might need the values on the `ValidationStack` that do belong - // to the current function (but not the current block), so draining would make - // continued validation of the current function impossible. We should most - // definitely not `return Ok(())` here, because there might be still more of the - // current function to validate. + if stack.ctrl_stack.is_empty() { + return Ok(()); + } + } + RETURN => { + let label_idx = stack.ctrl_stack.len() - 1; // return behaves the same as br + validate_intrablock_jump_and_generate_sidetable_entry( + wasm, label_idx, stack, sidetable, + )?; + stack.make_unspecified()?; } // call [t1*] -> [t2*] CALL => { @@ -229,7 +399,7 @@ fn read_instructions( } // unreachable: [t1*] -> [t2*] UNREACHABLE => { - stack.make_unspecified(); + stack.make_unspecified()?; } DROP => { stack.drop_val()?; @@ -250,7 +420,7 @@ fn read_instructions( LOCAL_TEE => { let local_idx = wasm.read_var_u32()? as LocalIdx; let local_ty = locals.get(local_idx).ok_or(Error::InvalidLocalIdx)?; - stack.assert_pop_val_type(*local_ty)?; + stack.assert_val_types_on_top(&[*local_ty])?; } // global.get [] -> [t] GLOBAL_GET => { diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 024f4a19..06c4b64e 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -11,6 +11,7 @@ use crate::core::reader::types::global::Global; use crate::core::reader::types::import::Import; use crate::core::reader::types::{FuncType, MemType, TableType}; use crate::core::reader::{WasmReadable, WasmReader}; +use crate::core::sidetable::Sidetable; use crate::{Error, Result}; pub(crate) mod code; @@ -31,7 +32,8 @@ pub struct ValidationInfo<'bytecode> { pub(crate) globals: Vec, #[allow(dead_code)] pub(crate) exports: Vec, - pub(crate) func_blocks: Vec, + /// Each block contains the validated code section and the generated sidetable + pub(crate) func_blocks: Vec<(Span, Sidetable)>, pub(crate) data: Vec, /// The start function which is automatically executed during instantiation pub(crate) start: Option, @@ -145,23 +147,28 @@ pub fn validate(wasm: &[u8]) -> Result { while (skip_section(&mut wasm, &mut header)?).is_some() {} - let func_blocks = handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| { - code::validate_code_section( - wasm, - h, - &types, - &functions, - &globals, - &memories, - &data_count, - &tables, - &elements, - &referenced_functions, - ) - })? - .unwrap_or_default(); + let func_blocks_sidetables = + handle_section(&mut wasm, &mut header, SectionTy::Code, |wasm, h| { + code::validate_code_section( + wasm, + h, + &types, + &functions, + &globals, + &memories, + &data_count, + &tables, + &elements, + &referenced_functions, + ) + })? + .unwrap_or_default(); - assert_eq!(func_blocks.len(), functions.len(), "these should be equal"); // TODO check if this is in the spec + assert_eq!( + func_blocks_sidetables.len(), + functions.len(), + "these should be equal" + ); // TODO check if this is in the spec while (skip_section(&mut wasm, &mut header)?).is_some() {} @@ -192,7 +199,7 @@ pub fn validate(wasm: &[u8]) -> Result { memories, globals, exports, - func_blocks, + func_blocks: func_blocks_sidetables, data: data_section, start, elements, diff --git a/src/validation/validation_stack.rs b/src/validation/validation_stack.rs index 6c2d7fa8..0863e0f3 100644 --- a/src/validation/validation_stack.rs +++ b/src/validation/validation_stack.rs @@ -1,23 +1,52 @@ -//! This module contains the [`ValidationStack`] data structure -//! -//! The [`ValidationStack`] is a unified stack, in the sense that it unifies both -//! [`ValidationStackEntry::Val`] and [`ValidationStackEntry::Label`]. It therefore mixes type -//! information with structured control flow information. -#![allow(unused)] // TODO remove this once sidetable implementation lands use super::Result; +use alloc::vec; use alloc::vec::Vec; -use crate::{Error, RefType, ValType}; +use crate::{ + core::reader::types::{FuncType, ResultType}, + Error, RefType, ValType, +}; #[derive(Debug, PartialEq, Eq)] pub struct ValidationStack { stack: Vec, + // TODO hide implementation + pub ctrl_stack: Vec, } impl ValidationStack { /// Initialize a new ValidationStack pub fn new() -> Self { - Self { stack: Vec::new() } + Self { + stack: Vec::new(), + ctrl_stack: vec![CtrlStackEntry { + label_info: LabelInfo::Untyped, + block_ty: FuncType { + params: ResultType { + valtypes: Vec::new(), + }, + returns: ResultType { + valtypes: Vec::new(), + }, + }, + height: 0, + unreachable: false, + }], + } + } + + pub(super) fn new_for_func(block_ty: FuncType) -> Self { + Self { + stack: Vec::new(), + ctrl_stack: vec![CtrlStackEntry { + label_info: LabelInfo::Func { + stps_to_backpatch: Vec::new(), + }, + block_ty, + height: 0, + unreachable: false, + }], + } } pub fn len(&self) -> usize { @@ -28,41 +57,26 @@ impl ValidationStack { self.stack.push(ValidationStackEntry::Val(valtype)); } - pub fn push_label(&mut self, label_info: LabelInfo) { - self.stack.push(ValidationStackEntry::Label(label_info)); - } - - pub fn peek_stack(&self) -> Option { + /// DANGER! only to be used within const validation! use within non-const validation may result in algorithmically incorrect validation + pub fn peek_const_validation_stack(&self) -> Option { self.stack.last().cloned() } - /// Similar to [`ValidationStack::pop`], because it pops a value from the stack, + /// Similar to [`ValidationStack::pop_valtype`], because it pops a value from the stack, /// but more public and doesn't actually return the popped value. - pub fn drop_val(&mut self) -> Result<()> { - match self.stack.pop().ok_or(Error::EndInvalidValueStack)? { - ValidationStackEntry::Val(_) => Ok(()), - _ => Err(Error::ExpectedAnOperand), - } + pub(super) fn drop_val(&mut self) -> Result<()> { + self.pop_valtype().map_err(|_| Error::ExpectedAnOperand)?; + Ok(()) } - /// This puts an unspecified element on top of the stack. - /// While the top of the stack is unspecified, arbitrary value types can be popped. - /// To undo this, a new label has to be pushed or an existing one has to be popped. - /// - /// See the documentation for [`ValidationStackEntry::UnspecifiedValTypes`] for more info. - pub fn make_unspecified(&mut self) { - // Pop everything until next label or until the stack is empty. - // This is okay, because these values cannot be accessed during execution ever again. - while let Some(entry) = self.stack.last() { - match entry { - ValidationStackEntry::Val(_) | ValidationStackEntry::UnspecifiedValTypes => { - self.stack.pop(); - } - ValidationStackEntry::Label(_) => break, - } - } - - self.stack.push(ValidationStackEntry::UnspecifiedValTypes) + pub(super) fn make_unspecified(&mut self) -> Result<()> { + let last_ctrl_stack_entry = self + .ctrl_stack + .last_mut() + .ok_or(Error::ValidationCtrlStackEmpty)?; + last_ctrl_stack_entry.unreachable = true; + self.stack.truncate(last_ctrl_stack_entry.height); + Ok(()) } /// Pop a [`ValidationStackEntry`] from the [`ValidationStack`] @@ -72,92 +86,109 @@ impl ValidationStack { /// - Returns `Ok(_)` with the former top-most [`ValidationStackEntry`] inside, if the stack had /// at least one element. /// - Returns `Err(_)` if the stack was already empty. - fn pop(&mut self) -> Result { - self.stack - .pop() - .ok_or(Error::InvalidValidationStackValType(None)) + fn pop_valtype(&mut self) -> Result { + // TODO unwrapping might not be the best option + // TODO ugly + // TODO return type should be Result<()> maybe? + let last_ctrl_stack_entry = self.ctrl_stack.last().unwrap(); + assert!(self.stack.len() >= last_ctrl_stack_entry.height); + if last_ctrl_stack_entry.height == self.stack.len() { + if last_ctrl_stack_entry.unreachable { + Ok(ValidationStackEntry::UnspecifiedValTypes) + } else { + Err(Error::EndInvalidValueStack) + } + } else { + //empty stack is covered with above check + self.stack.pop().ok_or(Error::EndInvalidValueStack) + } } pub fn assert_pop_ref_type(&mut self, expected_ty: Option) -> Result<()> { - let val = self.pop()?; - match val { - ValidationStackEntry::Val(v) => match v { - ValType::RefType(ref_type) => match expected_ty { - None => Ok(()), - Some(expected_ty) => { - if expected_ty == ref_type { - Ok(()) - } else { - Err(Error::DifferentRefTypes(ref_type, expected_ty)) - } - } - }, - _ => Err(Error::ExpectedARefType(v)), - }, - ValidationStackEntry::UnspecifiedValTypes => Err(Error::FoundUnspecifiedValTypes), - ValidationStackEntry::Label(li) => Err(Error::FoundLabel(li.kind)), + match self.pop_valtype()? { + ValidationStackEntry::Val(ValType::RefType(ref_type)) => { + expected_ty.map_or(Ok(()), |ty| { + (ty == ref_type) + .then_some(()) + .ok_or(Error::DifferentRefTypes(ref_type, ty)) + }) + } + ValidationStackEntry::Val(v) => Err(Error::ExpectedARefType(v)), + // TODO fix the thrown error type below + ValidationStackEntry::NumOrVecType => Err(Error::EndInvalidValueStack), + ValidationStackEntry::UnspecifiedValTypes => Ok(()), } } /// Assert the top-most [`ValidationStackEntry`] is a specific [`ValType`], after popping it from the [`ValidationStack`] + /// This assertion will unify the the top-most entry with `expected_ty`. /// /// # Returns /// /// - Returns `Ok(())` if the top-most [`ValidationStackEntry`] is a [`ValType`] identical to /// `expected_ty`. /// - Returns `Err(_)` otherwise. + /// pub fn assert_pop_val_type(&mut self, expected_ty: ValType) -> Result<()> { - if let Some(ValidationStackEntry::UnspecifiedValTypes) = self.stack.last() { - // An unspecified value is always correct, and will never disappear by popping. - return Ok(()); - } - - match self.pop()? { + match self.pop_valtype()? { ValidationStackEntry::Val(ty) => (ty == expected_ty) .then_some(()) .ok_or(Error::InvalidValidationStackValType(Some(ty))), - ValidationStackEntry::Label(li) => Err(Error::FoundLabel(li.kind)), - ValidationStackEntry::UnspecifiedValTypes => { - unreachable!("we just checked if the topmost entry is of this type") - } + ValidationStackEntry::NumOrVecType => match expected_ty { + ValType::NumType(_) => Ok(()), + ValType::VecType => Ok(()), + // TODO change this error + _ => Err(Error::InvalidValidationStackValType(None)), + }, + ValidationStackEntry::UnspecifiedValTypes => Ok(()), } } - /// Asserts that the values on top of the stack match those of a value iterator - /// - /// The last element of `expected_val_types` is compared to the top-most - /// [`ValidationStackEntry`], the second last `expected_val_types` element to the second top-most - /// [`ValidationStackEntry`] etc. - /// - /// Any occurence of the [`ValidationStackEntry::Label`] variant in the stack tail will cause an - /// error. This method does not mutate the [`ValidationStack::stack`] in any way. - /// - /// # Returns - /// - /// - `Ok(_)`, the tail of the stack matches the `expected_val_types` - /// - `Err(_)` otherwise - pub fn assert_val_types_on_top(&self, expected_val_types: &[ValType]) -> Result<()> { - let stack_tail = self - .stack - .get(self.stack.len() - expected_val_types.len()..) - .ok_or(Error::InvalidValType)?; - - // Now we check the valtypes in reverse. - // That way we can stop checking if we encounter an `UnspecifiedValTypes`. - - let mut current_expected_valtype = expected_val_types.iter().rev(); - for entry in stack_tail.iter().rev() { - match entry { - ValidationStackEntry::Label(label) => return Err(Error::EndInvalidValueStack), - ValidationStackEntry::Val(valtype) => { - if Some(valtype) != current_expected_valtype.next() { + // private fns to shut the borrow checker up when calling methods with mutable ref to self with immutable ref to self arguments + // TODO ugly but I can't come up with anything else better + + fn assert_val_types_on_top_with_custom_stacks( + stack: &mut Vec, + ctrl_stack: &[CtrlStackEntry], + expected_val_types: &[ValType], + ) -> Result<()> { + let last_ctrl_stack_entry = ctrl_stack.last().ok_or(Error::ValidationCtrlStackEmpty)?; + let stack_len = stack.len(); + + let rev_iterator = expected_val_types.iter().rev().enumerate(); + for (i, expected_ty) in rev_iterator { + if stack_len - last_ctrl_stack_entry.height <= i { + if last_ctrl_stack_entry.unreachable { + // Unify(t2*,expected_val_types) := [t2* expected_val_types] + stack.splice( + stack_len - i..stack_len - i, + expected_val_types[..expected_val_types.len() - i] + .iter() + .map(|ty| ValidationStackEntry::Val(*ty)), + ); + return Ok(()); + } else { + return Err(Error::EndInvalidValueStack); + } + } + + // the above height check ensures this access is valid + let actual_ty = &mut stack[stack_len - i - 1]; + + match actual_ty { + ValidationStackEntry::Val(actual_val_ty) => { + if *actual_val_ty != *expected_ty { return Err(Error::EndInvalidValueStack); } } + ValidationStackEntry::NumOrVecType => match expected_ty { + // unify the NumOrVecType to expected_ty + ValType::NumType(_) => *actual_ty = ValidationStackEntry::Val(*expected_ty), + ValType::VecType => *actual_ty = ValidationStackEntry::Val(*expected_ty), + _ => return Err(Error::EndInvalidValueStack), + }, ValidationStackEntry::UnspecifiedValTypes => { - // In case we find an `UnspecifiedValTypes`, we pretend that all expected valtypes are found. - // That's because this entry can expand to every possible combination of valtypes. - return Ok(()); + unreachable!("bottom type should not exist in the stack") } } } @@ -165,8 +196,50 @@ impl ValidationStack { Ok(()) } - /// Asserts that the valtypes on the stack match the expected valtypes. + fn assert_val_types_with_custom_stacks( + stack: &mut Vec, + ctrl_stack: &[CtrlStackEntry], + expected_val_types: &[ValType], + ) -> Result<()> { + ValidationStack::assert_val_types_on_top_with_custom_stacks( + stack, + ctrl_stack, + expected_val_types, + )?; + //if we can assert types in the above there is a last ctrl stack entry, this access is valid. + let last_ctrl_stack_entry = &ctrl_stack[ctrl_stack.len() - 1]; + if stack.len() == last_ctrl_stack_entry.height + expected_val_types.len() { + Ok(()) + } else { + Err(Error::EndInvalidValueStack) + } + } + /// Asserts that the values on top of the stack match those of a value iterator + /// This method will unify the types on the stack to the expected valtypes. + /// The last element of `expected_val_types` is unified to the top-most + /// [`ValidationStackEntry`], the second last `expected_val_types` element to the second top-most + /// [`ValidationStackEntry`] etc. + /// + /// Any unification failure or arity mismatch will cause an error. + /// + /// Any occurence of an error may leave the stack in an invalid state. + /// + /// # Returns + /// + /// - `Ok(_)`, the tail of the stack matches the `expected_val_types` + /// - `Err(_)` otherwise /// + pub(super) fn assert_val_types_on_top(&mut self, expected_val_types: &[ValType]) -> Result<()> { + ValidationStack::assert_val_types_on_top_with_custom_stacks( + &mut self.stack, + &self.ctrl_stack, + expected_val_types, + ) + } + + // TODO better documentation + /// Asserts that the valtypes on the stack match the expected valtypes and no other type is on the stack. + /// This method will unify the types on the stack to the expected valtypes. /// This starts by comparing the top-most valtype with the last element from `expected_val_types` and then continues downwards on the stack. /// If a label is reached and not all `expected_val_types` have been checked, the assertion fails. /// @@ -174,77 +247,60 @@ impl ValidationStack { /// /// - `Ok(())` if all expected valtypes were found /// - `Err(_)` otherwise - pub fn assert_val_types(&self, expected_val_types: &[ValType]) -> Result<()> { - let topmost_label_index = self.find_topmost_label_idx(); - - let first_valtype = topmost_label_index.map(|idx| idx + 1).unwrap_or(0); - - // Now we check the valtypes in reverse. - // That way we can stop checking if we encounter an `UnspecifiedValTypes`. - - let mut current_expected_valtype = expected_val_types.iter().rev(); - for entry in self.stack[first_valtype..].iter().rev() { - match entry { - ValidationStackEntry::Label(_) => unreachable!( - "we started at the top-most label so we cannot find any more labels" - ), - ValidationStackEntry::Val(valtype) => { - if Some(valtype) != current_expected_valtype.next() { - return Err(Error::EndInvalidValueStack); - } - } - ValidationStackEntry::UnspecifiedValTypes => { - return Ok(()); - } - } - } - - Ok(()) + pub(super) fn assert_val_types(&mut self, expected_val_types: &[ValType]) -> Result<()> { + ValidationStack::assert_val_types_with_custom_stacks( + &mut self.stack, + &self.ctrl_stack, + expected_val_types, + ) } - /// A helper to find the index of the top-most label in [`ValidationStack::stack`] - fn find_topmost_label_idx(&self) -> Option { - self.stack - .iter() - .enumerate() - .rev() - .find(|(_idx, entry)| matches!(entry, ValidationStackEntry::Label(_))) - .map(|(idx, _entry)| idx) + pub fn assert_val_types_of_label_jump_types_on_top(&mut self, label_idx: usize) -> Result<()> { + let label_types = self + .ctrl_stack + .get(self.ctrl_stack.len() - label_idx - 1) + .ok_or(Error::InvalidLabelIdx(label_idx))? + .label_types(); + ValidationStack::assert_val_types_on_top_with_custom_stacks( + &mut self.stack, + &self.ctrl_stack, + label_types, + ) } - /// Searches for the top-most label, then pops the label and all entry on top of that label. - /// Only the label's [`LabelInfo`] is returned. - /// - /// # Returns - /// - /// - `Ok(LabelInfo)` if a label has been found and popped - /// - `None` if no label was found on the stack - fn pop_label_and_above(&mut self) -> Option { - /// Delete all the values until the topmost label or until the stack is empty - match self.find_topmost_label_idx() { - Some(idx) => { - if self.stack.len() > idx + 1 { - self.stack.drain((idx + 1)..); - } - } - None => self.stack.clear(), - } - - // Pop the label itself - match self.pop() { - Ok(ValidationStackEntry::Label(info)) => Some(info), - Ok(_) => unreachable!( - "we just removed everything until the next label, thus new topmost entry must be a label" - ), - Err(_) => None, - } + // TODO is moving block_ty ok? + pub fn assert_push_ctrl(&mut self, label_info: LabelInfo, block_ty: FuncType) -> Result<()> { + self.assert_val_types_on_top(&block_ty.params.valtypes)?; + let height = self.stack.len() - block_ty.params.valtypes.len(); + self.ctrl_stack.push(CtrlStackEntry { + label_info, + block_ty, + height, + unreachable: false, + }); + Ok(()) } - /// Return true if the stack has at least one remaining label - pub fn has_remaining_label(&self) -> bool { - self.stack - .iter() - .any(|e| matches!(e, ValidationStackEntry::Label(_))) + pub fn assert_pop_ctrl(&mut self) -> Result<(LabelInfo, FuncType)> { + let return_types = &self + .ctrl_stack + .last() + .ok_or(Error::ValidationCtrlStackEmpty)? + .block_ty + .returns + .valtypes; + ValidationStack::assert_val_types_with_custom_stacks( + &mut self.stack, + &self.ctrl_stack, + return_types, + )?; + + //if we can assert types in the above there is a last ctrl stack entry, this access is valid. + let last_ctrl_stack_entry = self.ctrl_stack.pop().unwrap(); + Ok(( + last_ctrl_stack_entry.label_info, + last_ctrl_stack_entry.block_ty, + )) } } @@ -252,36 +308,80 @@ impl ValidationStack { pub enum ValidationStackEntry { /// A value Val(ValType), - - /// A label - Label(LabelInfo), - + /// Special variant to encode an uninstantiated type for `select` instruction + #[allow(unused)] + NumOrVecType, /// Special variant to encode that any possible number of [`ValType`]s could be here /// /// Caused by `return` and `unreachable`, as both can push an arbitrary number of values to the stack. /// /// When this variant is pushed onto the stack, all valtypes until the next lower label are deleted. /// They are not needed anymore because this variant can expand to all of them. + // TODO change this name to BottomType UnspecifiedValTypes, } - +// TODO hide implementation #[derive(Clone, Debug, PartialEq, Eq)] -pub struct LabelInfo { - pub kind: LabelKind, +pub struct CtrlStackEntry { + pub label_info: LabelInfo, + pub block_ty: FuncType, + pub height: usize, + pub unreachable: bool, } +impl CtrlStackEntry { + pub fn label_types(&self) -> &[ValType] { + if matches!(self.label_info, LabelInfo::Loop { .. }) { + &self.block_ty.params.valtypes + } else { + &self.block_ty.returns.valtypes + } + } +} + +// TODO replace LabelInfo with this +// TODO hide implementation +// TODO implementation coupled to Sidetable #[derive(Clone, Debug, PartialEq, Eq)] -pub enum LabelKind { - Block, - Loop, - If, +pub enum LabelInfo { + Block { + stps_to_backpatch: Vec, + }, + Loop { + ip: usize, + stp: usize, + }, + If { + stps_to_backpatch: Vec, + stp: usize, + }, + Func { + stps_to_backpatch: Vec, + }, + Untyped, } #[cfg(test)] mod tests { use crate::{NumType, RefType, ValType}; - use super::{LabelInfo, LabelKind, ValidationStack}; + use super::{CtrlStackEntry, FuncType, LabelInfo, ResultType, ValidationStack, Vec}; + + fn push_dummy_untyped_label(validation_stack: &mut ValidationStack) { + validation_stack.ctrl_stack.push(CtrlStackEntry { + label_info: LabelInfo::Untyped, + block_ty: FuncType { + params: ResultType { + valtypes: Vec::new(), + }, + returns: ResultType { + valtypes: Vec::new(), + }, + }, + height: validation_stack.len(), + unreachable: false, + }) + } #[test] fn push_then_pop() { @@ -304,41 +404,38 @@ mod tests { .unwrap(); } - #[test] - fn labels() { - let mut stack = ValidationStack::new(); + // TODO rewrite these + // #[test] + // fn labels() { + // let mut stack = ValidationStack::new(); - stack.push_valtype(ValType::NumType(NumType::I64)); - stack.push_label(LabelInfo { - kind: LabelKind::Block, - }); + // stack.push_valtype(ValType::NumType(NumType::I64)); + // push_dummy_func_label(&mut stack); - stack.push_label(LabelInfo { - kind: LabelKind::Loop, - }); + // push_dummy_block_label(&mut stack); - stack.push_valtype(ValType::VecType); + // stack.push_valtype(ValType::VecType); - // This removes the `ValType::VecType` and the `LabelKind::Loop` label - let popped_label = stack.pop_label_and_above().unwrap(); - assert_eq!( - popped_label, - LabelInfo { - kind: LabelKind::Loop, - } - ); + // // This removes the `ValType::VecType` and the `LabelKind::Loop` label + // let popped_label = stack.pop_label_and_above().unwrap(); + // assert_eq!( + // popped_label, + // LabelInfo { + // kind: LabelKind::Loop, + // } + // ); - let popped_label = stack.pop_label_and_above().unwrap(); - assert_eq!( - popped_label, - LabelInfo { - kind: LabelKind::Block, - } - ); + // let popped_label = stack.pop_label_and_above().unwrap(); + // assert_eq!( + // popped_label, + // LabelInfo { + // kind: LabelKind::Block, + // } + // ); - // The first valtype should still be there - stack.assert_pop_val_type(ValType::NumType(NumType::I64)); - } + // // The first valtype should still be there + // stack.assert_pop_val_type(ValType::NumType(NumType::I64)); + // } #[test] fn assert_valtypes() { @@ -356,9 +453,8 @@ mod tests { ]) .unwrap(); - stack.push_label(LabelInfo { - kind: LabelKind::Block, - }); + push_dummy_untyped_label(&mut stack); + stack.push_valtype(ValType::NumType(NumType::I32)); stack @@ -373,9 +469,7 @@ mod tests { stack.assert_val_types(&[]).unwrap(); stack.push_valtype(ValType::NumType(NumType::I32)); - stack.push_label(LabelInfo { - kind: LabelKind::Block, - }); + push_dummy_untyped_label(&mut stack); // Valtypes separated by a label should also not be detected stack.assert_val_types(&[]).unwrap(); @@ -417,11 +511,9 @@ mod tests { #[test] fn unspecified() { let mut stack = ValidationStack::new(); - stack.push_label(LabelInfo { - kind: LabelKind::Block, - }); + push_dummy_untyped_label(&mut stack); - stack.make_unspecified(); + stack.make_unspecified().unwrap(); // Now we can pop as many valtypes from the stack as we want stack @@ -433,15 +525,125 @@ mod tests { .unwrap(); // Let's remove the unspecified entry and the first label - let popped_label = stack.pop_label_and_above().unwrap(); - assert_eq!( - popped_label, - LabelInfo { - kind: LabelKind::Block, - } - ); + + // TODO hide implementation + stack.ctrl_stack.pop(); // Now there are no values left on the stack assert_eq!(stack.assert_val_types(&[]), Ok(())); } + + #[test] + fn unspecified2() { + let mut stack = ValidationStack::new(); + push_dummy_untyped_label(&mut stack); + + stack.make_unspecified().unwrap(); + + // Stack needs to keep track of unified types, I64 and F32 and I32 will appear. + stack + .assert_val_types(&[ + ValType::NumType(NumType::I64), + ValType::NumType(NumType::F32), + ValType::NumType(NumType::I32), + ]) + .unwrap(); + + stack.ctrl_stack.pop(); + + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::I32)), + Ok(()) + ); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::F32)), + Ok(()) + ); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::I64)), + Ok(()) + ); + } + + #[test] + fn unspecified3() { + let mut stack = ValidationStack::new(); + push_dummy_untyped_label(&mut stack); + + stack.make_unspecified().unwrap(); + + stack.push_valtype(ValType::NumType(NumType::I32)); + + // Stack needs to keep track of unified types, I64 and F32 will appear under I32. + // Stack needs to keep track of unified types, I64 and F32 and I32 will appear. + stack + .assert_val_types(&[ + ValType::NumType(NumType::I64), + ValType::NumType(NumType::F32), + ValType::NumType(NumType::I32), + ]) + .unwrap(); + + stack.ctrl_stack.pop(); + + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::I32)), + Ok(()) + ); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::F32)), + Ok(()) + ); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::I64)), + Ok(()) + ); + } + + #[test] + fn unspecified4() { + let mut stack = ValidationStack::new(); + + stack.push_valtype(ValType::VecType); + stack.push_valtype(ValType::NumType(NumType::I32)); + + push_dummy_untyped_label(&mut stack); + + stack.make_unspecified().unwrap(); + + stack.push_valtype(ValType::VecType); + stack.push_valtype(ValType::RefType(RefType::FuncRef)); + + // Stack needs to keep track of unified types, I64 and F32 will appear below VecType and RefType + // and above I32 and VecType + stack + .assert_val_types(&[ + ValType::NumType(NumType::I64), + ValType::NumType(NumType::F32), + ValType::VecType, + ValType::RefType(RefType::FuncRef), + ]) + .unwrap(); + + stack.ctrl_stack.pop(); + + assert_eq!( + stack.assert_pop_val_type(ValType::RefType(RefType::FuncRef)), + Ok(()) + ); + assert_eq!(stack.assert_pop_val_type(ValType::VecType), Ok(())); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::F32)), + Ok(()) + ); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::I64)), + Ok(()) + ); + assert_eq!( + stack.assert_pop_val_type(ValType::NumType(NumType::I32)), + Ok(()) + ); + assert_eq!(stack.assert_pop_val_type(ValType::VecType), Ok(())); + } } diff --git a/tests/lib.rs b/tests/lib.rs index bbbb713a..0aec2d61 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1 +1,2 @@ mod arithmetic; +mod structured_control_flow; diff --git a/tests/structured_control_flow/block.rs b/tests/structured_control_flow/block.rs new file mode 100644 index 00000000..b5d754a2 --- /dev/null +++ b/tests/structured_control_flow/block.rs @@ -0,0 +1,381 @@ +use wasm::{validate, RuntimeInstance}; + +/// Runs a function that does nothing and contains only a single empty block +#[test_log::test] +fn empty() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func (export "do_nothing") (block) + ) + ) + "#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + (), + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn branch() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func (export "with_branch") (result i32) + (block $outer_block (result i32) + (block $my_block (result i32) + i32.const 5 + br $my_block + i32.const 3 + ) + i32.const 3 + i32.add + ) + ) + ) + "#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 8, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +const BRANCH23_WAT: &str = r#" +(module + (func (export "with_branch") (result i32) + (block $outer_outer_block (result i32) + i64.const 3 + (block $outer_block (param i64) (result i32) (result i32) + drop + i32.const 14 + (block $my_block (result i32) + i32.const 11 + i32.const 8 + i32.const 5 + br {{LABEL}} + i32.const 3 + ) + i32.const 3 + i32.add + ) + drop + i32.const 5 + i32.add + ) + ) +) +"#; + +#[test_log::test] +fn branch2() { + let wat = String::from(BRANCH23_WAT).replace("{{LABEL}}", "$outer_block"); + let wasm_bytes = wat::parse_str(wat).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 13, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn branch3() { + let wat = String::from(BRANCH23_WAT).replace("{{LABEL}}", "$outer_outer_block"); + let wasm_bytes = wat::parse_str(wat).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 5, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn param_and_result() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func (export "add_one") (param $x i32) (result i32) + local.get $x + (block $my_block (param i32) (result i32) + i32.const 1 + i32.add + br $my_block + ) + ) + ) + "#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 7, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), 6) + .unwrap() + ); +} + +const RETURN_OUT_OF_BLOCK: &str = r#" +(module + (func (export "get_three") (result i32) + (block + i32.const 5 + i32.const 3 + {{RETURN}} + ) + unreachable + ) +) +"#; + +const RETURN_OUT_OF_BLOCK2: &str = r#" +(module + (func (export "get_three") (result i32) + (block + i32.const 5 + {{RETURN}} + drop + drop + drop + ) + unreachable + ) +) +"#; + +#[test_log::test] +fn return_out_of_block() { + let wat = String::from(RETURN_OUT_OF_BLOCK).replace("{{RETURN}}", "return"); + let wasm_bytes = wat::parse_str(wat).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 3, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn br_return_out_of_block() { + let wat = String::from(RETURN_OUT_OF_BLOCK).replace("{{RETURN}}", "br 1"); + let wasm_bytes = wat::parse_str(wat).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 3, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn return_out_of_block2() { + let wat = String::from(RETURN_OUT_OF_BLOCK2).replace("{{RETURN}}", "return"); + let wasm_bytes = wat::parse_str(wat).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 5, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn br_return_out_of_block2() { + let wat = String::from(RETURN_OUT_OF_BLOCK2).replace("{{RETURN}}", "br 1"); + let wasm_bytes = wat::parse_str(wat).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + assert_eq!( + 5, + instance + .invoke(&instance.get_function_by_index(0, 0).unwrap(), ()) + .unwrap() + ); +} + +#[test_log::test] +fn branch_if() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func (export "abs") (param $x i32) (result i32) + (block $my_block + local.get $x + i32.const 0 + i32.ge_s + br_if $my_block + local.get $x + i32.const -1 + i32.mul + return + ) + local.get $x + ) + ) + "#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let switch_case_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(6, instance.invoke(&switch_case_fn, 6).unwrap()); + assert_eq!(123, instance.invoke(&switch_case_fn, -123).unwrap()); + assert_eq!(0, instance.invoke(&switch_case_fn, 0).unwrap()); +} + +#[test_log::test] +fn recursive_fibonacci() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func (export "fibonacci") (param $x i32) (result i32) + (call $fib_internal + (i32.const 0) + (i32.const 1) + (local.get $x) + ) + ) + + (func $fib_internal (param $x0 i32) (param $x1 i32) (param $n_left i32) (result i32) + (block $zero_check + ;; if n_left reached 0, we return + local.get $n_left + br_if $zero_check + local.get $x0 + return + ) + + ;; otherwise decrement n_left + local.get $n_left + i32.const -1 + i32.add + local.set $n_left + + ;; store x1 temporarily + local.get $x1 + + ;; calculate new x1 + local.get $x0 + local.get $x1 + i32.add + local.set $x1 + + ;; set x0 to the previous x1 + local.set $x0 + + + (call $fib_internal + (local.get $x0) + (local.get $x1) + (local.get $n_left) + ) + ) + ) + "#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let fib_fn = instance.get_function_by_index(0, 0).unwrap(); + + let first_ten = (0..10) + .map(|n| instance.invoke(&fib_fn, n).unwrap()) + .collect::>(); + assert_eq!(&first_ten, &[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]); +} + +#[test_log::test] +fn switch_case() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func $switch_case (param $value i32) (result i32) + (block $default + (block $case4 + (block $case3 + (block $case2 + (block $case1 + local.get $value + (br_table $case1 $case2 $case3 $case4 $default) + ) + i32.const 1 + return + ) + i32.const 3 + return + ) + i32.const 5 + return + ) + i32.const 7 + return + ) + i32.const 9 + return + ) + (export "switch_case" (func $switch_case)) + )"#, + ) + .unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let switch_case_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(9, instance.invoke(&switch_case_fn, -5).unwrap()); + assert_eq!(9, instance.invoke(&switch_case_fn, -1).unwrap()); + assert_eq!(1, instance.invoke(&switch_case_fn, 0).unwrap()); + assert_eq!(3, instance.invoke(&switch_case_fn, 1).unwrap()); + assert_eq!(5, instance.invoke(&switch_case_fn, 2).unwrap()); + assert_eq!(7, instance.invoke(&switch_case_fn, 3).unwrap()); + assert_eq!(9, instance.invoke(&switch_case_fn, 4).unwrap()); + assert_eq!(9, instance.invoke(&switch_case_fn, 7).unwrap()); +} diff --git a/tests/structured_control_flow/if.rs b/tests/structured_control_flow/if.rs new file mode 100644 index 00000000..6f743aae --- /dev/null +++ b/tests/structured_control_flow/if.rs @@ -0,0 +1,168 @@ +use wasm::{validate, RuntimeInstance}; + +#[test_log::test] +fn odd_with_if_else() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $odd (param $n i32) (result i32) + local.get $n + i32.const 2 + i32.rem_s + (if (result i32) + (then + i32.const 1 + ) + (else + i32.const 0 + ) + ) + ) + + (export "odd" (func $odd)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let odd_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(1, instance.invoke(&odd_fn, -5).unwrap()); + assert_eq!(0, instance.invoke(&odd_fn, 0).unwrap()); + assert_eq!(1, instance.invoke(&odd_fn, 3).unwrap()); + assert_eq!(0, instance.invoke(&odd_fn, 4).unwrap()); +} + +#[test_log::test] +fn odd_with_if() { + let wasm_bytes = wat::parse_str( + r#"(module + (func $odd (param $n i32) (result i32) + local.get $n + i32.const 2 + i32.rem_s + (if + (then + i32.const 1 + return + ) + ) + i32.const 0 + ) + + (export "odd" (func $odd)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let odd_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(1, instance.invoke(&odd_fn, -5).unwrap()); + assert_eq!(0, instance.invoke(&odd_fn, 0).unwrap()); + assert_eq!(1, instance.invoke(&odd_fn, 3).unwrap()); + assert_eq!(0, instance.invoke(&odd_fn, 4).unwrap()); +} + +#[test_log::test] +fn odd_with_if_else_recursive() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $odd (param $n i32) (result i32) + local.get $n + (if (result i32) + (then + local.get $n + i32.const 1 + i32.sub + call $even + return + ) + (else + i32.const 0 + return + ) + ) + ) + + (func $even (param $n i32) (result i32) + local.get $n + (if (result i32) + (then + local.get $n + i32.const 1 + i32.sub + call $odd + return + ) + (else + i32.const 1 + return + ) + ) + ) + + (export "odd" (func $odd)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let even_odd_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(1, instance.invoke(&even_odd_fn, 1).unwrap()); + assert_eq!(0, instance.invoke(&even_odd_fn, 0).unwrap()); + assert_eq!(1, instance.invoke(&even_odd_fn, 3).unwrap()); + assert_eq!(0, instance.invoke(&even_odd_fn, 4).unwrap()); +} + +#[test_log::test] +fn recursive_fibonacci_if_else() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $fibonacci (param $n i32) (result i32) + local.get $n + i32.const 1 + i32.le_s + (if (result i32) + (then + i32.const 1 + return + ) + (else + local.get $n + i32.const 1 + i32.sub + call $fibonacci + local.get $n + i32.const 2 + i32.sub + call $fibonacci + i32.add + return + ) + ) + ) + + (export "fibonacci" (func $fibonacci)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let fibonacci_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(1, instance.invoke(&fibonacci_fn, -5).unwrap()); + assert_eq!(1, instance.invoke(&fibonacci_fn, 0).unwrap()); + assert_eq!(1, instance.invoke(&fibonacci_fn, 1).unwrap()); + assert_eq!(2, instance.invoke(&fibonacci_fn, 2).unwrap()); + assert_eq!(3, instance.invoke(&fibonacci_fn, 3).unwrap()); + assert_eq!(5, instance.invoke(&fibonacci_fn, 4).unwrap()); + assert_eq!(8, instance.invoke(&fibonacci_fn, 5).unwrap()); +} diff --git a/tests/structured_control_flow/loop.rs b/tests/structured_control_flow/loop.rs new file mode 100644 index 00000000..87817671 --- /dev/null +++ b/tests/structured_control_flow/loop.rs @@ -0,0 +1,70 @@ +use wasm::{validate, RuntimeInstance}; + +const FIBONACCI_WITH_LOOP_AND_BR_IF: &str = r#" +(module + (func $fibonacci (param $n i32) (result i32) + (local $prev i32) + (local $curr i32) + (local $counter i32) + + i32.const 0 + local.set $prev + i32.const 1 + local.set $curr + + local.get $n + i32.const 1 + i32.add + local.set $counter + + block $exit + loop $loop + local.get $counter + i32.const 1 + i32.le_s + br_if $exit + + local.get $curr + local.get $curr + local.get $prev + i32.add + local.set $curr + local.set $prev + + local.get $counter + i32.const 1 + i32.sub + local.set $counter + + br $loop + + drop + drop + drop + + end $loop + end $exit + + local.get $curr + ) + + (export "fibonacci" (func $fibonacci)) +)"#; + +#[test_log::test] +fn fibonacci_with_loop_and_br_if() { + let wasm_bytes = wat::parse_str(FIBONACCI_WITH_LOOP_AND_BR_IF).unwrap(); + + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let fibonacci_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(1, instance.invoke(&fibonacci_fn, -5).unwrap()); + assert_eq!(1, instance.invoke(&fibonacci_fn, 0).unwrap()); + assert_eq!(1, instance.invoke(&fibonacci_fn, 1).unwrap()); + assert_eq!(2, instance.invoke(&fibonacci_fn, 2).unwrap()); + assert_eq!(3, instance.invoke(&fibonacci_fn, 3).unwrap()); + assert_eq!(5, instance.invoke(&fibonacci_fn, 4).unwrap()); + assert_eq!(8, instance.invoke(&fibonacci_fn, 5).unwrap()); +} diff --git a/tests/structured_control_flow/mod.rs b/tests/structured_control_flow/mod.rs new file mode 100644 index 00000000..59fb5ec3 --- /dev/null +++ b/tests/structured_control_flow/mod.rs @@ -0,0 +1,3 @@ +mod block; +mod r#if; +mod r#loop;