From 451a31e4fe6bd4516ba1a288d89876ea0fc84156 Mon Sep 17 00:00:00 2001 From: Seonghyun Kim Date: Wed, 6 Sep 2023 17:13:30 +0900 Subject: [PATCH] Improve function call performance * Divide function into user-defined function and user-defined function with try-catch * Non-try-catch user-defined function should not use C++ try-catch * Reduce parameter count of Interpreter::interpret Signed-off-by: Seonghyun Kim --- src/interpreter/Interpreter.cpp | 80 ++++++++------------------------- src/interpreter/Interpreter.h | 79 ++++++++++++++++++++++++++++---- src/parser/WASMParser.cpp | 5 ++- src/runtime/Function.cpp | 30 ++++++------- src/runtime/Function.h | 16 +++++++ src/runtime/Instance.cpp | 5 +++ src/runtime/Instance.h | 2 +- src/runtime/Module.cpp | 44 +++++++----------- src/runtime/Module.h | 6 ++- 9 files changed, 148 insertions(+), 119 deletions(-) diff --git a/src/interpreter/Interpreter.cpp b/src/interpreter/Interpreter.cpp index 9e7d8c35c..7f72d30ba 100644 --- a/src/interpreter/Interpreter.cpp +++ b/src/interpreter/Interpreter.cpp @@ -150,53 +150,10 @@ ByteCodeTable::ByteCodeTable() b.m_opcodeInAddress = const_cast(FillByteCodeOpcodeAddress[0]); #endif size_t pc = reinterpret_cast(&b); - Interpreter::interpret(dummyState, pc, nullptr, nullptr, nullptr, nullptr, nullptr); + Interpreter::interpret(dummyState, pc, nullptr, nullptr); #endif } -ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, - uint8_t* bp) -{ - DefinedFunction* df = state.currentFunction()->asDefinedFunction(); - ModuleFunction* mf = df->moduleFunction(); - size_t programCounter = reinterpret_cast(mf->byteCode()); - Instance* instance = df->instance(); - while (true) { - try { - return interpret(state, programCounter, bp, instance, instance->m_memories, instance->m_tables, instance->m_globals); - } catch (std::unique_ptr& e) { - for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) { - if (e->m_programCounterInfo[i - 1].first == &state) { - programCounter = 
e->m_programCounterInfo[i - 1].second; - break; - } - } - if (e->isUserException()) { - bool isCatchSucessful = false; - Tag* tag = e->tag().value(); - size_t offset = programCounter - reinterpret_cast(mf->byteCode()); - for (const auto& item : mf->catchInfo()) { - if (item.m_tryStart <= offset && offset < item.m_tryEnd) { - if (item.m_tagIndex == std::numeric_limits::max() || state.currentFunction()->asDefinedFunction()->instance()->tag(item.m_tagIndex) == tag) { - programCounter = item.m_catchStartPosition + reinterpret_cast(mf->byteCode()); - uint8_t* sp = bp + item.m_stackSizeToBe; - if (item.m_tagIndex != std::numeric_limits::max() && tag->functionType()->paramStackSize()) { - memcpy(sp, e->userExceptionData().data(), tag->functionType()->paramStackSize()); - } - isCatchSucessful = true; - break; - } - } - } - if (isCatchSucessful) { - continue; - } - } - throw std::unique_ptr(std::move(e)); - } - } -} - template ALWAYS_INLINE void writeValue(uint8_t* bp, ByteCodeStackOffset offset, const T& v) { @@ -488,11 +445,10 @@ static void initAddressToOpcodeTable() ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, size_t programCounter, uint8_t* bp, - Instance* instance, - Memory** memories, - Table** tables, - Global** globals) + Instance* instance) { + Memory** memories = reinterpret_cast(reinterpret_cast(instance) + Instance::alignedSize()); + state.m_programCounterPointer = &programCounter; #define ADD_PROGRAM_COUNTER(codeName) programCounter += sizeof(codeName); @@ -947,7 +903,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { GlobalGet32* code = (GlobalGet32*)programCounter; ASSERT(code->index() < instance->module()->numberOfGlobalTypes()); - globals[code->index()]->value().writeNBytesToMemory<4>(bp + code->dstOffset()); + instance->m_globals[code->index()]->value().writeNBytesToMemory<4>(bp + code->dstOffset()); ADD_PROGRAM_COUNTER(GlobalGet32); NEXT_INSTRUCTION(); } @@ -957,7 +913,7 @@ ByteCodeStackOffset* 
Interpreter::interpret(ExecutionState& state, { GlobalGet64* code = (GlobalGet64*)programCounter; ASSERT(code->index() < instance->module()->numberOfGlobalTypes()); - globals[code->index()]->value().writeNBytesToMemory<8>(bp + code->dstOffset()); + instance->m_globals[code->index()]->value().writeNBytesToMemory<8>(bp + code->dstOffset()); ADD_PROGRAM_COUNTER(GlobalGet64); NEXT_INSTRUCTION(); } @@ -967,7 +923,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { GlobalGet128* code = (GlobalGet128*)programCounter; ASSERT(code->index() < instance->module()->numberOfGlobalTypes()); - globals[code->index()]->value().writeNBytesToMemory<16>(bp + code->dstOffset()); + instance->m_globals[code->index()]->value().writeNBytesToMemory<16>(bp + code->dstOffset()); ADD_PROGRAM_COUNTER(GlobalGet128); NEXT_INSTRUCTION(); } @@ -977,7 +933,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { GlobalSet32* code = (GlobalSet32*)programCounter; ASSERT(code->index() < instance->module()->numberOfGlobalTypes()); - Value& val = globals[code->index()]->value(); + Value& val = instance->m_globals[code->index()]->value(); val.readFromStack<4>(bp + code->srcOffset()); ADD_PROGRAM_COUNTER(GlobalSet32); NEXT_INSTRUCTION(); @@ -988,7 +944,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { GlobalSet64* code = (GlobalSet64*)programCounter; ASSERT(code->index() < instance->module()->numberOfGlobalTypes()); - Value& val = globals[code->index()]->value(); + Value& val = instance->m_globals[code->index()]->value(); val.readFromStack<8>(bp + code->srcOffset()); ADD_PROGRAM_COUNTER(GlobalSet64); NEXT_INSTRUCTION(); @@ -999,7 +955,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { GlobalSet128* code = (GlobalSet128*)programCounter; ASSERT(code->index() < instance->module()->numberOfGlobalTypes()); - Value& val = globals[code->index()]->value(); + Value& val = instance->m_globals[code->index()]->value(); 
val.readFromStack<16>(bp + code->srcOffset()); ADD_PROGRAM_COUNTER(GlobalSet128); NEXT_INSTRUCTION(); @@ -1161,7 +1117,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { TableGet* code = (TableGet*)programCounter; ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes()); - Table* table = tables[code->tableIndex()]; + Table* table = instance->m_tables[code->tableIndex()]; void* val = table->getElement(state, readValue(bp, code->srcOffset())); writeValue(bp, code->dstOffset(), val); @@ -1174,7 +1130,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { TableSet* code = (TableSet*)programCounter; ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes()); - Table* table = tables[code->tableIndex()]; + Table* table = instance->m_tables[code->tableIndex()]; void* ptr = readValue(bp, code->src1Offset()); table->setElement(state, readValue(bp, code->src0Offset()), ptr); @@ -1187,7 +1143,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { TableGrow* code = (TableGrow*)programCounter; ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes()); - Table* table = tables[code->tableIndex()]; + Table* table = instance->m_tables[code->tableIndex()]; size_t size = table->size(); uint64_t newSize = (uint64_t)readValue(bp, code->src1Offset()) + size; @@ -1210,7 +1166,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { TableSize* code = (TableSize*)programCounter; ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes()); - Table* table = tables[code->tableIndex()]; + Table* table = instance->m_tables[code->tableIndex()]; size_t size = table->size(); writeValue(bp, code->dstOffset(), size); @@ -1224,8 +1180,8 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, TableCopy* code = (TableCopy*)programCounter; ASSERT(code->dstIndex() < instance->module()->numberOfTableTypes()); ASSERT(code->srcIndex() < 
instance->module()->numberOfTableTypes()); - Table* dstTable = tables[code->dstIndex()]; - Table* srcTable = tables[code->srcIndex()]; + Table* dstTable = instance->m_tables[code->dstIndex()]; + Table* srcTable = instance->m_tables[code->srcIndex()]; uint32_t dstIndex = readValue(bp, code->srcOffsets()[0]); uint32_t srcIndex = readValue(bp, code->srcOffsets()[1]); @@ -1242,7 +1198,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, { TableFill* code = (TableFill*)programCounter; ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes()); - Table* table = tables[code->tableIndex()]; + Table* table = instance->m_tables[code->tableIndex()]; int32_t index = readValue(bp, code->srcOffsets()[0]); void* ptr = readValue(bp, code->srcOffsets()[1]); @@ -1264,7 +1220,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state, int32_t size = readValue(bp, code->srcOffsets()[2]); ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes()); - Table* table = tables[code->tableIndex()]; + Table* table = instance->m_tables[code->tableIndex()]; table->init(state, instance, &sg, dstStart, srcStart, size); ADD_PROGRAM_COUNTER(TableInit); NEXT_INSTRUCTION(); diff --git a/src/interpreter/Interpreter.h b/src/interpreter/Interpreter.h index 09eb12fbe..1ad74ed4f 100644 --- a/src/interpreter/Interpreter.h +++ b/src/interpreter/Interpreter.h @@ -18,6 +18,10 @@ #define __WalrusInterpreter__ #include "runtime/ExecutionState.h" +#include "runtime/Function.h" +#include "runtime/Instance.h" +#include "runtime/Module.h" +#include "runtime/Tag.h" #include "interpreter/ByteCode.h" namespace Walrus { @@ -28,19 +32,78 @@ class Table; class Global; class Interpreter { -public: - static ByteCodeStackOffset* interpret(ExecutionState& state, - uint8_t* bp); - private: friend class ByteCodeTable; + friend class DefinedFunction; + friend class DefinedFunctionWithTryCatch; + + template + ALWAYS_INLINE static void callInterpreter(ExecutionState& state, 
DefinedFunction* function, uint8_t* bp, ByteCodeStackOffset* offsets, + uint16_t parameterOffsetCount, uint16_t resultOffsetCount) + { + ExecutionState newState(state, function); + CHECK_STACK_LIMIT(newState); + + auto moduleFunction = function->moduleFunction(); + ALLOCA(uint8_t, functionStackBase, moduleFunction->requiredStackSize()); + + // init parameter space + for (size_t i = 0; i < parameterOffsetCount; i++) { + ((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i])); + } + + size_t programCounter = reinterpret_cast(moduleFunction->byteCode()); + ByteCodeStackOffset* resultOffsets; + if (considerException) { + while (true) { + try { + resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance()); + break; + } catch (std::unique_ptr& e) { + for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) { + if (e->m_programCounterInfo[i - 1].first == &newState) { + programCounter = e->m_programCounterInfo[i - 1].second; + break; + } + } + if (e->isUserException()) { + bool isCatchSucessful = false; + Tag* tag = e->tag().value(); + size_t offset = programCounter - reinterpret_cast(moduleFunction->byteCode()); + for (const auto& item : moduleFunction->catchInfo()) { + if (item.m_tryStart <= offset && offset < item.m_tryEnd) { + if (item.m_tagIndex == std::numeric_limits::max() || function->instance()->tag(item.m_tagIndex) == tag) { + programCounter = item.m_catchStartPosition + reinterpret_cast(moduleFunction->byteCode()); + uint8_t* sp = functionStackBase + item.m_stackSizeToBe; + if (item.m_tagIndex != std::numeric_limits::max() && tag->functionType()->paramStackSize()) { + memcpy(sp, e->userExceptionData().data(), tag->functionType()->paramStackSize()); + } + isCatchSucessful = true; + break; + } + } + } + if (isCatchSucessful) { + continue; + } + } + throw std::unique_ptr(std::move(e)); + } + } + } else { + resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance()); + } + + offsets += 
parameterOffsetCount; + for (size_t i = 0; i < resultOffsetCount; i++) { + *((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i])); + } + } + static ByteCodeStackOffset* interpret(ExecutionState& state, size_t programCounter, uint8_t* bp, - Instance* instance, - Memory** memories, - Table** tables, - Global** globals); + Instance* instance); static void callOperation(ExecutionState& state, size_t& programCounter, diff --git a/src/parser/WASMParser.cpp b/src/parser/WASMParser.cpp index ca85478ae..09ddd8de6 100644 --- a/src/parser/WASMParser.cpp +++ b/src/parser/WASMParser.cpp @@ -471,8 +471,8 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { bool m_inInitExpr; Walrus::ModuleFunction* m_currentFunction; Walrus::FunctionType* m_currentFunctionType; - uint32_t m_initialFunctionStackSize; - uint32_t m_functionStackSizeSoFar; + uint16_t m_initialFunctionStackSize; + uint16_t m_functionStackSizeSoFar; std::vector m_vmStack; std::vector m_blockInfo; @@ -1879,6 +1879,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { { BlockInfo b(BlockInfo::TryCatch, sigType, *this); m_blockInfo.push_back(b); + m_currentFunction->m_hasTryCatch = true; } void processCatchExpr(Index tagIndex) diff --git a/src/runtime/Function.cpp b/src/runtime/Function.cpp index 6cdf6fbb0..a225f7565 100644 --- a/src/runtime/Function.cpp +++ b/src/runtime/Function.cpp @@ -20,6 +20,8 @@ #include "runtime/Store.h" #include "interpreter/Interpreter.h" #include "runtime/Module.h" +#include "runtime/Tag.h" +#include "runtime/Instance.h" #include "runtime/Value.h" namespace Walrus { @@ -28,7 +30,12 @@ DefinedFunction* DefinedFunction::createDefinedFunction(Store* store, Instance* instance, ModuleFunction* moduleFunction) { - DefinedFunction* func = new DefinedFunction(instance, moduleFunction); + DefinedFunction* func; + if (moduleFunction->hasTryCatch()) { + func = new DefinedFunctionWithTryCatch(instance, moduleFunction); + } else { + func = new 
DefinedFunction(instance, moduleFunction); + } store->appendExtern(func); return func; } @@ -88,22 +95,13 @@ void DefinedFunction::call(ExecutionState& state, Value* argv, Value* result) void DefinedFunction::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets, uint16_t parameterOffsetCount, uint16_t resultOffsetCount) { - ExecutionState newState(state, this); - CHECK_STACK_LIMIT(newState); - - ALLOCA(uint8_t, functionStackBase, m_moduleFunction->requiredStackSize()); - - // init parameter space - for (size_t i = 0; i < parameterOffsetCount; i++) { - ((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i])); - } - - auto resultOffsets = Interpreter::interpret(newState, functionStackBase); + Interpreter::callInterpreter(state, this, bp, offsets, parameterOffsetCount, resultOffsetCount); +} - offsets += parameterOffsetCount; - for (size_t i = 0; i < resultOffsetCount; i++) { - *((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i])); - } +void DefinedFunctionWithTryCatch::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets, + uint16_t parameterOffsetCount, uint16_t resultOffsetCount) +{ + Interpreter::callInterpreter(state, this, bp, offsets, parameterOffsetCount, resultOffsetCount); } ImportedFunction* ImportedFunction::createImportedFunction(Store* store, diff --git a/src/runtime/Function.h b/src/runtime/Function.h index ad6e87659..c4d8ab4ed 100644 --- a/src/runtime/Function.h +++ b/src/runtime/Function.h @@ -121,6 +121,22 @@ class DefinedFunction : public Function { ModuleFunction* m_moduleFunction; }; +class DefinedFunctionWithTryCatch : public DefinedFunction { + friend class DefinedFunction; + friend class Module; + +public: + virtual void interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets, + uint16_t parameterOffsetCount, uint16_t resultOffsetCount) override; + +protected: + DefinedFunctionWithTryCatch(Instance* instance, + ModuleFunction* 
moduleFunction) + : DefinedFunction(instance, moduleFunction) + { + } +}; + class ImportedFunction : public Function { public: typedef std::function ImportedFunctionCallback; diff --git a/src/runtime/Instance.cpp b/src/runtime/Instance.cpp index 106264766..44212845a 100644 --- a/src/runtime/Instance.cpp +++ b/src/runtime/Instance.cpp @@ -51,6 +51,11 @@ void Instance::freeInstance(Instance* instance) Instance::Instance(Module* module) : m_module(module) + , m_memories(nullptr) + , m_globals(nullptr) + , m_tables(nullptr) + , m_functions(nullptr) + , m_tags(nullptr) { module->store()->appendInstance(this); } diff --git a/src/runtime/Instance.h b/src/runtime/Instance.h index 5b7a27be3..dfb2984fc 100644 --- a/src/runtime/Instance.h +++ b/src/runtime/Instance.h @@ -124,7 +124,7 @@ class Instance : public Object { Instance(Module* module); ~Instance() {} - static size_t alignedSize() + static constexpr size_t alignedSize() { return (sizeof(Instance) + sizeof(void*) - 1) & ~(sizeof(void*) - 1); } diff --git a/src/runtime/Module.cpp b/src/runtime/Module.cpp index 6f252fc9e..9c24b8e1c 100644 --- a/src/runtime/Module.cpp +++ b/src/runtime/Module.cpp @@ -31,8 +31,9 @@ namespace Walrus { ModuleFunction::ModuleFunction(FunctionType* functionType) - : m_functionType(functionType) - , m_requiredStackSize(std::max(m_functionType->paramStackSize(), m_functionType->resultStackSize())) + : m_hasTryCatch(false) + , m_requiredStackSize(std::max(functionType->paramStackSize(), functionType->resultStackSize())) + , m_functionType(functionType) { } @@ -218,19 +219,16 @@ Instance* Module::instantiate(ExecutionState& state, const ExternVector& imports struct RunData { Instance* instance; Module* module; - Value::Type type; ModuleFunction* mf; size_t index; - } data = { instance, this, globalType->type(), globalType->function(), globIndex }; + } data = { instance, this, globalType->function(), globIndex }; Walrus::Trap trap; trap.run([](Walrus::ExecutionState& state, void* d) { RunData* data 
= reinterpret_cast(d); - ALLOCA(uint8_t, functionStackBase, data->mf->requiredStackSize()); - - DefinedFunction fakeFunction(data->instance, data->mf); - ExecutionState newState(state, &fakeFunction); - auto resultOffset = Interpreter::interpret(newState, functionStackBase); - data->instance->m_globals[data->index]->setValue(Value(data->type, functionStackBase + resultOffset[0])); + DefinedFunctionWithTryCatch fakeFunction(data->instance, data->mf); + Value result; + fakeFunction.call(state, nullptr, &result); + data->instance->m_globals[data->index]->setValue(result); }, &data); } @@ -250,20 +248,14 @@ Instance* Module::instantiate(ExecutionState& state, const ExternVector& imports struct RunData { Element* elem; Instance* instance; - Module* module; uint32_t& index; - } data = { elem, instance, this, index }; + } data = { elem, instance, index }; Walrus::Trap trap; trap.run([](Walrus::ExecutionState& state, void* d) { RunData* data = reinterpret_cast(d); - ALLOCA(uint8_t, functionStackBase, data->elem->moduleFunction()->requiredStackSize()); - - DefinedFunction fakeFunction(data->instance, - data->elem->moduleFunction()); - ExecutionState newState(state, &fakeFunction); - - auto resultOffset = Interpreter::interpret(newState, functionStackBase); - Value offset(Value::I32, functionStackBase + resultOffset[0]); + DefinedFunctionWithTryCatch fakeFunction(data->instance, data->elem->moduleFunction()); + Value offset; + fakeFunction.call(state, nullptr, &offset); data->index = offset.asI32(); }, &data); @@ -303,14 +295,10 @@ Instance* Module::instantiate(ExecutionState& state, const ExternVector& imports auto result = trap.run([](Walrus::ExecutionState& state, void* d) { RunData* data = reinterpret_cast(d); if (data->init->moduleFunction()->currentByteCodeSize()) { - ALLOCA(uint8_t, functionStackBase, data->init->moduleFunction()->requiredStackSize()); - - DefinedFunction fakeFunction(data->instance, - data->init->moduleFunction()); - ExecutionState newState(state, 
&fakeFunction); - - auto resultOffset = Interpreter::interpret(newState, functionStackBase); - Value offset(Value::I32, functionStackBase + resultOffset[0]); + DefinedFunctionWithTryCatch fakeFunction(data->instance, + data->init->moduleFunction()); + Value offset; + fakeFunction.call(state, nullptr, &offset); Memory* m = data->instance->memory(0); const auto& initData = data->init->initData(); diff --git a/src/runtime/Module.h b/src/runtime/Module.h index 1e7d2866f..7c41acf8a 100644 --- a/src/runtime/Module.h +++ b/src/runtime/Module.h @@ -175,8 +175,9 @@ class ModuleFunction { ModuleFunction(FunctionType* functionType); + bool hasTryCatch() const { return m_hasTryCatch; } + uint16_t requiredStackSize() const { return m_requiredStackSize; } FunctionType* functionType() const { return m_functionType; } - uint32_t requiredStackSize() const { return m_requiredStackSize; } template void pushByteCode(const CodeType& code) @@ -218,8 +219,9 @@ class ModuleFunction { } private: + bool m_hasTryCatch; + uint16_t m_requiredStackSize; FunctionType* m_functionType; - uint32_t m_requiredStackSize; ValueTypeVector m_local; Vector> m_byteCode; #if !defined(NDEBUG)