Skip to content

Commit

Permalink
Implement WASI file functions and improve WASI stucture
Browse files Browse the repository at this point in the history
Introduce class WasiFunction to enable WASI to acces Instance resources.
Implement WASI types and path_open and fd_close functions.

Signed-off-by: Adam Laszlo Kulcsar <[email protected]>
  • Loading branch information
kulcsaradam committed Sep 7, 2023
1 parent 451a31e commit b2295da
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 18 deletions.
1 change: 0 additions & 1 deletion src/interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ namespace Walrus {

ByteCodeTable g_byteCodeTable;


// SIMD Structures
template <typename T, uint8_t L>
struct SIMDValue {
Expand Down
45 changes: 45 additions & 0 deletions src/runtime/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,49 @@ void ImportedFunction::call(ExecutionState& state, Value* argv, Value* result)
m_callback(newState, argv, result, m_data);
}

WasiFunction* WasiFunction::createWasiFunction(Store* store,
FunctionType* functionType,
WasiFunctionCallback callback,
Instance* instance)
{
WasiFunction* func = new WasiFunction(functionType,
callback,
instance);
store->appendExtern(func);
return func;
}

void WasiFunction::interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount)
{
const FunctionType* ft = functionType();
const ValueTypeVector& paramTypeInfo = ft->param();
const ValueTypeVector& resultTypeInfo = ft->result();

ALLOCA(Value, paramVector, sizeof(Value) * paramTypeInfo.size());
ALLOCA(Value, resultVector, sizeof(Value) * resultTypeInfo.size());

size_t offsetIndex = 0;
size_t size = paramTypeInfo.size();
Value* paramVectorStart = paramVector;
for (size_t i = 0; i < size; i++) {
paramVector[i] = Value(paramTypeInfo[i], bp + offsets[offsetIndex]);
offsetIndex += valueFunctionCopyCount(paramTypeInfo[i]);
}

call(state, paramVectorStart, resultVector);

for (size_t i = 0; i < resultTypeInfo.size(); i++) {
resultVector[i].writeToMemory(bp + offsets[offsetIndex]);
offsetIndex += valueFunctionCopyCount(resultTypeInfo[i]);
}
}

void WasiFunction::call(ExecutionState& state, Value* argv, Value* result)
{
ExecutionState newState(state, this);
CHECK_STACK_LIMIT(newState);
m_callback(newState, argv, result, this->m_runningInstance);
}

} // namespace Walrus
48 changes: 48 additions & 0 deletions src/runtime/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class FunctionType;
class ModuleFunction;
class DefinedFunction;
class ImportedFunction;
class WasiFunction;

class Function : public Extern {
public:
Expand Down Expand Up @@ -71,6 +72,10 @@ class Function : public Extern {
{
return false;
}
virtual bool isWasiFunction() const
{
return false;
}

DefinedFunction* asDefinedFunction()
{
Expand All @@ -84,6 +89,12 @@ class Function : public Extern {
return reinterpret_cast<ImportedFunction*>(this);
}

WasiFunction* asWasiFunctions()
{
assert(isWasiFunction());
return reinterpret_cast<WasiFunction*>(this);
}

protected:
Function(FunctionType* functionType)
: m_functionType(functionType)
Expand Down Expand Up @@ -168,6 +179,43 @@ class ImportedFunction : public Function {
void* m_data;
};

class WasiFunction : public Function {
public:
typedef std::function<void(ExecutionState& state, Value* argv, Value* result, Instance* instance)> WasiFunctionCallback;

static WasiFunction* createWasiFunction(Store* store,
FunctionType* functionType,
WasiFunctionCallback callback,
Instance* instance);

virtual bool isWasiFunction() const override
{
return true;
}

void setRunningInstance(Instance* instance)
{
m_runningInstance = instance;
}

virtual void call(ExecutionState& state, Value* argv, Value* result) override;
virtual void interpreterCall(ExecutionState& state, uint8_t* bp, ByteCodeStackOffset* offsets,
uint16_t parameterOffsetCount, uint16_t resultOffsetCount) override;

protected:
WasiFunction(FunctionType* functionType,
WasiFunctionCallback callback,
Instance* instance)
: Function(functionType)
, m_callback(callback)
, m_runningInstance(instance)
{
}

WasiFunctionCallback m_callback;
Instance* m_runningInstance;
};

} // namespace Walrus

#endif // __WalrusFunction__
6 changes: 5 additions & 1 deletion src/runtime/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "interpreter/ByteCode.h"
#include "interpreter/Interpreter.h"
#include "parser/WASMParser.h"
#include "wasi/Wasi.h"

namespace Walrus {

Expand Down Expand Up @@ -137,7 +138,10 @@ Instance* Module::instantiate(ExecutionState& state, const ExternVector& imports
if (!imports[i]->asFunction()->functionType()->equals(m_imports[i]->functionType())) {
Trap::throwException(state, "imported function type mismatch");
}
instance->m_functions[funcIndex++] = imports[i]->asFunction();
instance->m_functions[funcIndex] = imports[i]->asFunction();
if (imports[i]->asFunction()->isWasiFunction()) {
instance->m_functions[funcIndex++]->asWasiFunctions()->setRunningInstance(instance);
}
break;
}
case ImportType::Global: {
Expand Down
25 changes: 25 additions & 0 deletions src/runtime/SpecTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ class SpecTestFunctionTypes {
// The R is meant to represent the results, after R are the result types.
NONE = 0,
I32R,
I32_RI32,
RI32,
I64R,
F32R,
F64R,
I32F32R,
F64F64R,
I32I32I32I32I64I64I32_RI32I32,
INVALID,
INDEX_NUM,
};
Expand All @@ -57,6 +59,14 @@ class SpecTestFunctionTypes {
param->push_back(Value::Type::I32);
m_vector[index++] = new FunctionType(param, result);
}
{
// I32_RI32
param = new ValueTypeVector();
result = new ValueTypeVector();
param->push_back(Value::Type::I32);
result->push_back(Value::Type::I32);
m_vector[index++] = new FunctionType(param, result);
}
{
// RI32
param = new ValueTypeVector();
Expand Down Expand Up @@ -101,6 +111,21 @@ class SpecTestFunctionTypes {
param->push_back(Value::Type::F64);
m_vector[index++] = new FunctionType(param, result);
}
{
// I32I32I32I32I64I64I32_RI32I32
param = new ValueTypeVector();
result = new ValueTypeVector();
param->push_back(Value::Type::I32);
param->push_back(Value::Type::I32);
param->push_back(Value::Type::I32);
param->push_back(Value::Type::I32);
param->push_back(Value::Type::I64);
param->push_back(Value::Type::I64);
param->push_back(Value::Type::I32);
result->push_back(Value::Type::I32);
result->push_back(Value::Type::I32);
m_vector[index++] = new FunctionType(param, result);
}
{
// INVALID
param = new ValueTypeVector();
Expand Down
20 changes: 10 additions & 10 deletions src/shell/Shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ static void printF64(double v)
printf("%s : f64\n", formatDecmialString(ss.str()).c_str());
}

static Trap::TrapResult executeWASM(Store* store, const std::string& filename, const std::vector<uint8_t>& src, SpecTestFunctionTypes& functionTypes, WASI* wasi,
static Trap::TrapResult executeWASM(Store* store, const std::string& filename, const std::vector<uint8_t>& src, SpecTestFunctionTypes& functionTypes,
std::map<std::string, Instance*>* registeredInstanceMap = nullptr)
{
auto parseResult = WASMParser::parseBinary(store, filename, src.data(), src.size());
Expand Down Expand Up @@ -231,11 +231,11 @@ static Trap::TrapResult executeWASM(Store* store, const std::string& filename, c
nullptr));
}
} else if (import->moduleName() == "wasi_snapshot_preview1") {
Walrus::WASI::WasiFunc* wasiImportFunc = wasi->find(import->fieldName());
Walrus::WASI::WasiFunc* wasiImportFunc = WASI::find(import->fieldName());
if (wasiImportFunc != nullptr) {
FunctionType* fn = functionTypes[wasiImportFunc->functionType];
if (fn->equals(import->functionType())) {
importValues.push_back(ImportedFunction::createImportedFunction(
importValues.push_back(WasiFunction::createWasiFunction(
store,
const_cast<FunctionType*>(import->functionType()),
wasiImportFunc->ptr,
Expand Down Expand Up @@ -579,7 +579,7 @@ static Instance* fetchInstance(wabt::Var& moduleVar, std::map<size_t, Instance*>
return registeredInstanceMap[moduleVar.name()];
}

static void executeWAST(Store* store, const std::string& filename, const std::vector<uint8_t>& src, SpecTestFunctionTypes& functionTypes, WASI* wasi)
static void executeWAST(Store* store, const std::string& filename, const std::vector<uint8_t>& src, SpecTestFunctionTypes& functionTypes)
{
auto lexer = wabt::WastLexer::CreateBufferLexer("test.wabt", src.data(), src.size());
if (!lexer) {
Expand Down Expand Up @@ -607,7 +607,7 @@ static void executeWAST(Store* store, const std::string& filename, const std::ve
case wabt::CommandType::ScriptModule: {
auto* moduleCommand = static_cast<wabt::ModuleCommand*>(command.get());
auto buf = readModuleData(&moduleCommand->module);
auto trapResult = executeWASM(store, filename, buf->data, functionTypes, wasi, &registeredInstanceMap);
auto trapResult = executeWASM(store, filename, buf->data, functionTypes, &registeredInstanceMap);
RELEASE_ASSERT(!trapResult.exception);
instanceMap[commandCount] = store->getLastInstance();
if (moduleCommand->module.name.size()) {
Expand Down Expand Up @@ -668,7 +668,7 @@ static void executeWAST(Store* store, const std::string& filename, const std::ve
auto tsm = dynamic_cast<wabt::TextScriptModule*>(m);
RELEASE_ASSERT(tsm);
auto buf = readModuleData(&tsm->module);
auto trapResult = executeWASM(store, filename, buf->data, functionTypes, wasi, &registeredInstanceMap);
auto trapResult = executeWASM(store, filename, buf->data, functionTypes, &registeredInstanceMap);
RELEASE_ASSERT(trapResult.exception);
std::string& s = trapResult.exception->message();
RELEASE_ASSERT(s.find(assertModuleUninstantiable->text) == 0);
Expand Down Expand Up @@ -705,7 +705,7 @@ static void executeWAST(Store* store, const std::string& filename, const std::ve
} else {
buf = dsm->data;
}
auto trapResult = executeWASM(store, filename, buf, functionTypes, wasi);
auto trapResult = executeWASM(store, filename, buf, functionTypes);
RELEASE_ASSERT(trapResult.exception);
std::string& actual = trapResult.exception->message();
printf("assertModuleInvalid (expect compile error: '%s', actual '%s'(line: %d)) : OK\n", assertModuleInvalid->text.data(), actual.data(), assertModuleInvalid->module->location().line);
Expand All @@ -728,7 +728,7 @@ static void executeWAST(Store* store, const std::string& filename, const std::ve
} else {
buf = dsm->data;
}
auto trapResult = executeWASM(store, filename, buf, functionTypes, wasi);
auto trapResult = executeWASM(store, filename, buf, functionTypes);
RELEASE_ASSERT(trapResult.exception);
break;
}
Expand Down Expand Up @@ -905,14 +905,14 @@ int main(int argc, char* argv[])
if (!argParser.exportToRun.empty()) {
runExports(store, filePath, buf, argParser.exportToRun);
} else {
auto trapResult = executeWASM(store, filePath, buf, functionTypes, wasi);
auto trapResult = executeWASM(store, filePath, buf, functionTypes);
if (trapResult.exception) {
fprintf(stderr, "Uncaught Exception: %s\n", trapResult.exception->message().data());
return -1;
}
}
} else if (endsWith(filePath, "wat") || endsWith(filePath, "wast")) {
executeWAST(store, filePath, buf, functionTypes, wasi);
executeWAST(store, filePath, buf, functionTypes);
}
} else {
printf("Cannot open file %s\n", filePath.data());
Expand Down
37 changes: 37 additions & 0 deletions src/wasi/Fd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023-present Samsung Electronics Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

namespace Walrus {

void WASI::fd_close(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
WASI::wasi_fd_t fd = static_cast<uint32_t>(argv[0].asI32());
for (auto file : WASI::m_wasiFiles) {
if (file->directory) {
result[0] = Value(static_cast<uint16_t>(WASI::wasi_errno::success));
}
if (file->fd == fd) {
file->ptr->close();
if (file->ptr->bad()) {
result[0] = Value(static_cast<uint16_t>(WASI::wasi_errno::noent));
}
result[0] = Value(static_cast<uint16_t>(WASI::wasi_errno::success));
}
}
result[0] = Value(static_cast<uint16_t>(WASI::wasi_errno::inval));
}

} // namespace Walrus
Loading

0 comments on commit b2295da

Please sign in to comment.