Skip to content

Commit

Permalink
Improve function call performance
Browse files Browse the repository at this point in the history
* Divide function into user-defined function and user-defined function /w try-catch
* Non try-catch user-defined function should not use c++ try-catch
* Reduce parameter count of Interpreter::interpret

Signed-off-by: Seonghyun Kim <[email protected]>
  • Loading branch information
ksh8281 authored and clover2123 committed Sep 7, 2023
1 parent 0a6c695 commit 451a31e
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 119 deletions.
80 changes: 18 additions & 62 deletions src/interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,53 +150,10 @@ ByteCodeTable::ByteCodeTable()
b.m_opcodeInAddress = const_cast<void*>(FillByteCodeOpcodeAddress[0]);
#endif
size_t pc = reinterpret_cast<size_t>(&b);
Interpreter::interpret(dummyState, pc, nullptr, nullptr, nullptr, nullptr, nullptr);
Interpreter::interpret(dummyState, pc, nullptr, nullptr);
#endif
}

ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
uint8_t* bp)
{
DefinedFunction* df = state.currentFunction()->asDefinedFunction();
ModuleFunction* mf = df->moduleFunction();
size_t programCounter = reinterpret_cast<size_t>(mf->byteCode());
Instance* instance = df->instance();
while (true) {
try {
return interpret(state, programCounter, bp, instance, instance->m_memories, instance->m_tables, instance->m_globals);
} catch (std::unique_ptr<Exception>& e) {
for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) {
if (e->m_programCounterInfo[i - 1].first == &state) {
programCounter = e->m_programCounterInfo[i - 1].second;
break;
}
}
if (e->isUserException()) {
bool isCatchSucessful = false;
Tag* tag = e->tag().value();
size_t offset = programCounter - reinterpret_cast<size_t>(mf->byteCode());
for (const auto& item : mf->catchInfo()) {
if (item.m_tryStart <= offset && offset < item.m_tryEnd) {
if (item.m_tagIndex == std::numeric_limits<uint32_t>::max() || state.currentFunction()->asDefinedFunction()->instance()->tag(item.m_tagIndex) == tag) {
programCounter = item.m_catchStartPosition + reinterpret_cast<size_t>(mf->byteCode());
uint8_t* sp = bp + item.m_stackSizeToBe;
if (item.m_tagIndex != std::numeric_limits<uint32_t>::max() && tag->functionType()->paramStackSize()) {
memcpy(sp, e->userExceptionData().data(), tag->functionType()->paramStackSize());
}
isCatchSucessful = true;
break;
}
}
}
if (isCatchSucessful) {
continue;
}
}
throw std::unique_ptr<Exception>(std::move(e));
}
}
}

template <typename T>
ALWAYS_INLINE void writeValue(uint8_t* bp, ByteCodeStackOffset offset, const T& v)
{
Expand Down Expand Up @@ -488,11 +445,10 @@ static void initAddressToOpcodeTable()
ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
size_t programCounter,
uint8_t* bp,
Instance* instance,
Memory** memories,
Table** tables,
Global** globals)
Instance* instance)
{
Memory** memories = reinterpret_cast<Memory**>(reinterpret_cast<uintptr_t>(instance) + Instance::alignedSize());

state.m_programCounterPointer = &programCounter;

#define ADD_PROGRAM_COUNTER(codeName) programCounter += sizeof(codeName);
Expand Down Expand Up @@ -947,7 +903,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalGet32* code = (GlobalGet32*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
globals[code->index()]->value().writeNBytesToMemory<4>(bp + code->dstOffset());
instance->m_globals[code->index()]->value().writeNBytesToMemory<4>(bp + code->dstOffset());
ADD_PROGRAM_COUNTER(GlobalGet32);
NEXT_INSTRUCTION();
}
Expand All @@ -957,7 +913,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalGet64* code = (GlobalGet64*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
globals[code->index()]->value().writeNBytesToMemory<8>(bp + code->dstOffset());
instance->m_globals[code->index()]->value().writeNBytesToMemory<8>(bp + code->dstOffset());
ADD_PROGRAM_COUNTER(GlobalGet64);
NEXT_INSTRUCTION();
}
Expand All @@ -967,7 +923,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalGet128* code = (GlobalGet128*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
globals[code->index()]->value().writeNBytesToMemory<16>(bp + code->dstOffset());
instance->m_globals[code->index()]->value().writeNBytesToMemory<16>(bp + code->dstOffset());
ADD_PROGRAM_COUNTER(GlobalGet128);
NEXT_INSTRUCTION();
}
Expand All @@ -977,7 +933,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalSet32* code = (GlobalSet32*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
Value& val = globals[code->index()]->value();
Value& val = instance->m_globals[code->index()]->value();
val.readFromStack<4>(bp + code->srcOffset());
ADD_PROGRAM_COUNTER(GlobalSet32);
NEXT_INSTRUCTION();
Expand All @@ -988,7 +944,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalSet64* code = (GlobalSet64*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
Value& val = globals[code->index()]->value();
Value& val = instance->m_globals[code->index()]->value();
val.readFromStack<8>(bp + code->srcOffset());
ADD_PROGRAM_COUNTER(GlobalSet64);
NEXT_INSTRUCTION();
Expand All @@ -999,7 +955,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
GlobalSet128* code = (GlobalSet128*)programCounter;
ASSERT(code->index() < instance->module()->numberOfGlobalTypes());
Value& val = globals[code->index()]->value();
Value& val = instance->m_globals[code->index()]->value();
val.readFromStack<16>(bp + code->srcOffset());
ADD_PROGRAM_COUNTER(GlobalSet128);
NEXT_INSTRUCTION();
Expand Down Expand Up @@ -1161,7 +1117,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableGet* code = (TableGet*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
void* val = table->getElement(state, readValue<uint32_t>(bp, code->srcOffset()));
writeValue(bp, code->dstOffset(), val);

Expand All @@ -1174,7 +1130,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableSet* code = (TableSet*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
void* ptr = readValue<void*>(bp, code->src1Offset());
table->setElement(state, readValue<uint32_t>(bp, code->src0Offset()), ptr);

Expand All @@ -1187,7 +1143,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableGrow* code = (TableGrow*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
size_t size = table->size();

uint64_t newSize = (uint64_t)readValue<uint32_t>(bp, code->src1Offset()) + size;
Expand All @@ -1210,7 +1166,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableSize* code = (TableSize*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
size_t size = table->size();
writeValue<uint32_t>(bp, code->dstOffset(), size);

Expand All @@ -1224,8 +1180,8 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
TableCopy* code = (TableCopy*)programCounter;
ASSERT(code->dstIndex() < instance->module()->numberOfTableTypes());
ASSERT(code->srcIndex() < instance->module()->numberOfTableTypes());
Table* dstTable = tables[code->dstIndex()];
Table* srcTable = tables[code->srcIndex()];
Table* dstTable = instance->m_tables[code->dstIndex()];
Table* srcTable = instance->m_tables[code->srcIndex()];

uint32_t dstIndex = readValue<uint32_t>(bp, code->srcOffsets()[0]);
uint32_t srcIndex = readValue<uint32_t>(bp, code->srcOffsets()[1]);
Expand All @@ -1242,7 +1198,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
{
TableFill* code = (TableFill*)programCounter;
ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];

int32_t index = readValue<int32_t>(bp, code->srcOffsets()[0]);
void* ptr = readValue<void*>(bp, code->srcOffsets()[1]);
Expand All @@ -1264,7 +1220,7 @@ ByteCodeStackOffset* Interpreter::interpret(ExecutionState& state,
int32_t size = readValue<int32_t>(bp, code->srcOffsets()[2]);

ASSERT(code->tableIndex() < instance->module()->numberOfTableTypes());
Table* table = tables[code->tableIndex()];
Table* table = instance->m_tables[code->tableIndex()];
table->init(state, instance, &sg, dstStart, srcStart, size);
ADD_PROGRAM_COUNTER(TableInit);
NEXT_INSTRUCTION();
Expand Down
79 changes: 71 additions & 8 deletions src/interpreter/Interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
#define __WalrusInterpreter__

#include "runtime/ExecutionState.h"
#include "runtime/Function.h"
#include "runtime/Instance.h"
#include "runtime/Module.h"
#include "runtime/Tag.h"
#include "interpreter/ByteCode.h"

namespace Walrus {
Expand All @@ -28,19 +32,78 @@ class Table;
class Global;

class Interpreter {
public:
static ByteCodeStackOffset* interpret(ExecutionState& state,
uint8_t* bp);

private:
friend class ByteCodeTable;
friend class DefinedFunction;
friend class DefinedFunctionWithTryCatch;

template <const bool considerException>
ALWAYS_INLINE static void callInterpreter(ExecutionState& state, DefinedFunction* function, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
ExecutionState newState(state, function);
CHECK_STACK_LIMIT(newState);

auto moduleFunction = function->moduleFunction();
ALLOCA(uint8_t, functionStackBase, moduleFunction->requiredStackSize());

// init parameter space
for (size_t i = 0; i < parameterOffsetCount; i++) {
((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i]));
}

size_t programCounter = reinterpret_cast<size_t>(moduleFunction->byteCode());
ByteCodeStackOffset* resultOffsets;
if (considerException) {
while (true) {
try {
resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance());
break;
} catch (std::unique_ptr<Exception>& e) {
for (size_t i = e->m_programCounterInfo.size(); i > 0; i--) {
if (e->m_programCounterInfo[i - 1].first == &newState) {
programCounter = e->m_programCounterInfo[i - 1].second;
break;
}
}
if (e->isUserException()) {
bool isCatchSucessful = false;
Tag* tag = e->tag().value();
size_t offset = programCounter - reinterpret_cast<size_t>(moduleFunction->byteCode());
for (const auto& item : moduleFunction->catchInfo()) {
if (item.m_tryStart <= offset && offset < item.m_tryEnd) {
if (item.m_tagIndex == std::numeric_limits<uint32_t>::max() || function->instance()->tag(item.m_tagIndex) == tag) {
programCounter = item.m_catchStartPosition + reinterpret_cast<size_t>(moduleFunction->byteCode());
uint8_t* sp = functionStackBase + item.m_stackSizeToBe;
if (item.m_tagIndex != std::numeric_limits<uint32_t>::max() && tag->functionType()->paramStackSize()) {
memcpy(sp, e->userExceptionData().data(), tag->functionType()->paramStackSize());
}
isCatchSucessful = true;
break;
}
}
}
if (isCatchSucessful) {
continue;
}
}
throw std::unique_ptr<Exception>(std::move(e));
}
}
} else {
resultOffsets = interpret(newState, programCounter, functionStackBase, function->instance());
}

offsets += parameterOffsetCount;
for (size_t i = 0; i < resultOffsetCount; i++) {
*((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i]));
}
}

static ByteCodeStackOffset* interpret(ExecutionState& state,
size_t programCounter,
uint8_t* bp,
Instance* instance,
Memory** memories,
Table** tables,
Global** globals);
Instance* instance);

static void callOperation(ExecutionState& state,
size_t& programCounter,
Expand Down
5 changes: 3 additions & 2 deletions src/parser/WASMParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
bool m_inInitExpr;
Walrus::ModuleFunction* m_currentFunction;
Walrus::FunctionType* m_currentFunctionType;
uint32_t m_initialFunctionStackSize;
uint32_t m_functionStackSizeSoFar;
uint16_t m_initialFunctionStackSize;
uint16_t m_functionStackSizeSoFar;

std::vector<VMStackInfo> m_vmStack;
std::vector<BlockInfo> m_blockInfo;
Expand Down Expand Up @@ -1879,6 +1879,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
{
BlockInfo b(BlockInfo::TryCatch, sigType, *this);
m_blockInfo.push_back(b);
m_currentFunction->m_hasTryCatch = true;
}

void processCatchExpr(Index tagIndex)
Expand Down
30 changes: 14 additions & 16 deletions src/runtime/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "runtime/Store.h"
#include "interpreter/Interpreter.h"
#include "runtime/Module.h"
#include "runtime/Tag.h"
#include "runtime/Instance.h"
#include "runtime/Value.h"

namespace Walrus {
Expand All @@ -28,7 +30,12 @@ DefinedFunction* DefinedFunction::createDefinedFunction(Store* store,
Instance* instance,
ModuleFunction* moduleFunction)
{
DefinedFunction* func = new DefinedFunction(instance, moduleFunction);
DefinedFunction* func;
if (moduleFunction->hasTryCatch()) {
func = new DefinedFunctionWithTryCatch(instance, moduleFunction);
} else {
func = new DefinedFunction(instance, moduleFunction);
}
store->appendExtern(func);
return func;
}
Expand Down Expand Up @@ -88,22 +95,13 @@ void DefinedFunction::call(ExecutionState& state, Value* argv, Value* result)
void DefinedFunction::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
ExecutionState newState(state, this);
CHECK_STACK_LIMIT(newState);

ALLOCA(uint8_t, functionStackBase, m_moduleFunction->requiredStackSize());

// init parameter space
for (size_t i = 0; i < parameterOffsetCount; i++) {
((size_t*)functionStackBase)[i] = *((size_t*)(bp + offsets[i]));
}

auto resultOffsets = Interpreter::interpret(newState, functionStackBase);
Interpreter::callInterpreter<false>(state, this, bp, offsets, parameterOffsetCount, resultOffsetCount);
}

offsets += parameterOffsetCount;
for (size_t i = 0; i < resultOffsetCount; i++) {
*((size_t*)(bp + offsets[i])) = *((size_t*)(functionStackBase + resultOffsets[i]));
}
void DefinedFunctionWithTryCatch::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
Interpreter::callInterpreter<true>(state, this, bp, offsets, parameterOffsetCount, resultOffsetCount);
}

ImportedFunction* ImportedFunction::createImportedFunction(Store* store,
Expand Down
16 changes: 16 additions & 0 deletions src/runtime/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ class DefinedFunction : public Function {
ModuleFunction* m_moduleFunction;
};

class DefinedFunctionWithTryCatch : public DefinedFunction {
friend class DefinedFunction;
friend class Module;

public:
virtual void interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount) override;

protected:
DefinedFunctionWithTryCatch(Instance* instance,
ModuleFunction* moduleFunction)
: DefinedFunction(instance, moduleFunction)
{
}
};

class ImportedFunction : public Function {
public:
typedef std::function<void(ExecutionState& state, Value* argv, Value* result, void* data)> ImportedFunctionCallback;
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/Instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ void Instance::freeInstance(Instance* instance)

Instance::Instance(Module* module)
: m_module(module)
, m_memories(nullptr)
, m_globals(nullptr)
, m_tables(nullptr)
, m_functions(nullptr)
, m_tags(nullptr)
{
module->store()->appendInstance(this);
}
Expand Down
Loading

0 comments on commit 451a31e

Please sign in to comment.