Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WASI: add args_get and args_sizes_get #308

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/shell/Shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct ParseOptions {
// WASI options
std::vector<std::string> wasi_envs;
std::vector<std::pair<std::string, std::string>> wasi_dirs;
int argsIndex = -1;
};

static uint32_t s_JITFlags = 0;
Expand Down Expand Up @@ -1039,7 +1040,7 @@ static void runExports(Store* store, const std::string& filename, const std::vec
&data);
}

static void parseArguments(int argc, char* argv[], ParseOptions& options)
static void parseArguments(int argc, const char* argv[], ParseOptions& options)
{
for (int i = 1; i < argc; i++) {
if (strlen(argv[i]) >= 2 && argv[i][0] == '-') { // parse command line option
Expand Down Expand Up @@ -1083,6 +1084,15 @@ static void parseArguments(int argc, char* argv[], ParseOptions& options)
options.wasi_dirs.push_back(std::make_pair(argv[i + 2], argv[i + 1]));
i += 2;
continue;
} else if (strcmp(argv[i], "--args") == 0) {
if (i + 1 == argc || argv[i + 1][0] == '-') {
fprintf(stderr, "error: --args requires an argument\n");
exit(1);
}
++i;
options.fileNames.emplace_back(argv[i]);
options.argsIndex = i;
break;
} else if (strcmp(argv[i], "--help") == 0) {
fprintf(stdout, "Usage: walrus [OPTIONS] <INPUT>\n\n");
fprintf(stdout, "OPTIONS:\n");
Expand All @@ -1095,6 +1105,7 @@ static void parseArguments(int argc, char* argv[], ParseOptions& options)
#endif
fprintf(stdout, "\t--mapdirs <HOST_DIR> <VIRTUAL_DIR>\n\t\tMap real directories to virtual ones for WASI functions to use.\n\t\tExample: ./walrus test.wasm --mapdirs this/real/directory/ this/virtual/directory\n\n");
fprintf(stdout, "\t--env\n\t\tShare host environment to walrus WASI.\n\n");
fprintf(stdout, "\t--args <MODULE_FILE_NAME> [<ARG1> <ARG2> ... <ARGN>]\n\t\tRun Webassembly module with arguments: must be followed by the name of the Webassembly module file, then optionally following arguments which are passed on to the module\n\t\tExample: ./walrus --args test.wasm 'hello' 'world' 42\n\n");
exit(0);
}
}
Expand All @@ -1117,7 +1128,7 @@ static void parseArguments(int argc, char* argv[], ParseOptions& options)
}
}

int main(int argc, char* argv[])
int main(int argc, const char* argv[])
{
#ifndef NDEBUG
setbuf(stdout, NULL);
Expand Down Expand Up @@ -1164,8 +1175,8 @@ int main(int argc, char* argv[])
init_options.out = 1;
init_options.err = 2;
init_options.fd_table_size = 3;
init_options.argc = 0;
init_options.argv = nullptr;
init_options.argc = (options.argsIndex == -1 ? 0 : argc - options.argsIndex);
init_options.argv = (options.argsIndex == -1 ? nullptr : argv + options.argsIndex);
init_options.envp = envp.data();
init_options.preopenc = dirs.size();
init_options.preopens = dirs.data();
Expand Down
133 changes: 103 additions & 30 deletions src/wasi/WASI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,23 @@ static void* get_memory_pointer(Instance* instance, Value& value, size_t size)
return memory->buffer() + offset;
}

template <class T, int maxSize>
class TemporaryData {
public:
TemporaryData(size_t size)
{
/* Check that more memory is requested than the maximum provided by malloc. */
if (size > ((~static_cast<size_t>(0)) / sizeof(T))) {
m_data = nullptr;
return;
}

size *= sizeof(T);

if (size <= sizeof(m_stackData)) {
m_data = m_stackData;
} else {
m_data = malloc(size);
m_data = reinterpret_cast<T*>(malloc(size));
}
}

Expand All @@ -59,14 +68,14 @@ class TemporaryData {
}
}

void* data()
T* data()
{
return m_data;
}

private:
void* m_stackData[8];
void* m_data;
T m_stackData[maxSize];
T* m_data;
};

void WASI::initialize(uvwasi_t* uvwasi)
Expand All @@ -93,6 +102,43 @@ WASI::WasiFuncInfo* WASI::find(const std::string& funcName)
return nullptr;
}

void WASI::args_get(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
uvwasi_size_t argc;
uvwasi_size_t bufSize;
uvwasi_args_sizes_get(WASI::g_uvwasi, &argc, &bufSize);

uint32_t* uvArgv = reinterpret_cast<uint32_t*>(get_memory_pointer(instance, argv[0], argc * sizeof(uint32_t)));
char* uvArgBuf = reinterpret_cast<char*>(get_memory_pointer(instance, argv[1], bufSize));

if (uvArgv == nullptr || uvArgBuf == nullptr) {
result[0] = Value(WasiErrNo::inval);
return;
}

TemporaryData<void*, 8> pointers(argc);

if (pointers.data() == nullptr) {
result[0] = Value(WasiErrNo::inval);
return;
}

char** data = reinterpret_cast<char**>(pointers.data());
result[0] = Value(static_cast<uint16_t>(uvwasi_args_get(WASI::g_uvwasi, data, uvArgBuf)));

for (uvwasi_size_t i = 0; i < argc; i++) {
uvArgv[i] = data[i] - uvArgBuf;
}
}

void WASI::args_sizes_get(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
uvwasi_size_t* uvArgc = reinterpret_cast<uvwasi_size_t*>(get_memory_pointer(instance, argv[0], sizeof(uint32_t)));
uvwasi_size_t* uvArgvBufSize = reinterpret_cast<uvwasi_size_t*>(get_memory_pointer(instance, argv[1], sizeof(uint32_t)));

result[0] = Value(static_cast<int16_t>(uvwasi_args_sizes_get(WASI::g_uvwasi, uvArgc, uvArgvBufSize)));
}

void WASI::proc_exit(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
ASSERT(argv[0].type() == Value::I32);
Expand Down Expand Up @@ -123,51 +169,78 @@ void WASI::clock_time_get(ExecutionState& state, Value* argv, Value* result, Ins
void WASI::fd_write(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
uint32_t fd = argv[0].asI32();
uint32_t iovptr = argv[1].asI32();
uint32_t iovcnt = argv[2].asI32();
uint32_t out = argv[3].asI32();
size_t iovsLen = static_cast<size_t>(argv[2].asI32());
uint32_t* nwritten = reinterpret_cast<uint32_t*>(get_memory_pointer(instance, argv[3], sizeof(uint32_t)));
uint32_t* iovptr = reinterpret_cast<uint32_t*>(get_memory_pointer(instance, argv[1], iovsLen * (sizeof(uint32_t) << 1)));

if (uint64_t(iovptr) + iovcnt >= instance->memory(0)->sizeInByte()) {
result[0] = Value(static_cast<int16_t>(WasiErrNo::inval));
if (iovptr == nullptr || nwritten == nullptr) {
result[0] = Value(WasiErrNo::inval);
return;
}

std::vector<uvwasi_ciovec_t> iovs(iovcnt);
for (uint32_t i = 0; i < iovcnt; i++) {
iovs[i].buf = instance->memory(0)->buffer() + *reinterpret_cast<uint32_t*>(instance->memory(0)->buffer() + iovptr + i * 8);
iovs[i].buf_len = *reinterpret_cast<uint32_t*>(instance->memory(0)->buffer() + iovptr + 4 + i * 8);
}
TemporaryData<uvwasi_ciovec_t, 8> iovsBuffer(iovsLen);
uvwasi_ciovec_t* iovs = iovsBuffer.data();
uint64_t sizeInByte = instance->memory(0)->sizeInByte();
uint8_t* buffer = instance->memory(0)->buffer();

uvwasi_size_t* out_addr = (uvwasi_size_t*)(instance->memory(0)->buffer() + out);
for (uint32_t i = 0; i < iovsLen; i++) {
if (iovptr[1] > sizeInByte || iovptr[0] > sizeInByte - iovptr[1]) {
result[0] = Value(WasiErrNo::inval);
return;
}

iovs[i].buf = buffer + iovptr[0];
iovs[i].buf_len = iovptr[1];
iovptr += 2;
}

result[0] = Value(static_cast<int16_t>(uvwasi_fd_write(WASI::g_uvwasi, fd, iovs.data(), iovs.size(), out_addr)));
*(instance->memory(0)->buffer() + out) = *out_addr;
result[0] = Value(uvwasi_fd_write(WASI::g_uvwasi, fd, iovs, iovsLen, nwritten));
}

void WASI::fd_read(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
uint32_t fd = argv[0].asI32();
uint32_t iovptr = argv[1].asI32();
uint32_t iovcnt = argv[2].asI32();
uint32_t out = argv[3].asI32();

std::vector<uvwasi_iovec_t> iovs(iovcnt);
for (uint32_t i = 0; i < iovcnt; i++) {
iovs[i].buf = instance->memory(0)->buffer() + *reinterpret_cast<uint32_t*>(instance->memory(0)->buffer() + iovptr + i * 8);
iovs[i].buf_len = *reinterpret_cast<uint32_t*>(instance->memory(0)->buffer() + iovptr + 4 + i * 8);
size_t iovsLen = static_cast<size_t>(argv[2].asI32());
uint32_t* nread = reinterpret_cast<uint32_t*>(get_memory_pointer(instance, argv[3], sizeof(uint32_t)));
uint32_t* iovptr = reinterpret_cast<uint32_t*>(get_memory_pointer(instance, argv[1], iovsLen * (sizeof(uint32_t) << 1)));

if (iovptr == nullptr || nread == nullptr) {
result[0] = Value(WasiErrNo::inval);
return;
}

uvwasi_size_t* out_addr = (uvwasi_size_t*)(instance->memory(0)->buffer() + out);
TemporaryData<uvwasi_iovec_t, 8> iovsBuffer(iovsLen);
uvwasi_iovec_t* iovs = iovsBuffer.data();
uint64_t sizeInByte = instance->memory(0)->sizeInByte();
uint8_t* buffer = instance->memory(0)->buffer();

for (uint32_t i = 0; i < iovsLen; i++) {
if (iovptr[1] > sizeInByte || iovptr[0] > sizeInByte - iovptr[1]) {
result[0] = Value(WasiErrNo::inval);
return;
}

iovs[i].buf = buffer + iovptr[0];
iovs[i].buf_len = iovptr[1];
iovptr += 2;
}

result[0] = Value(static_cast<int16_t>(uvwasi_fd_read(WASI::g_uvwasi, fd, iovs.data(), iovs.size(), out_addr)));
*(instance->memory(0)->buffer() + out) = *out_addr;
result[0] = Value(uvwasi_fd_read(WASI::g_uvwasi, fd, iovs, iovsLen, nread));
}

void WASI::fd_close(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
uint32_t fd = argv[0].asI32();

result[0] = Value(static_cast<int16_t>(uvwasi_fd_close(WASI::g_uvwasi, fd)));
result[0] = Value(uvwasi_fd_close(WASI::g_uvwasi, fd));
}

void WASI::fd_fdstat_get(ExecutionState& state, Value* argv, Value* result, Instance* instance)
{
uint32_t fd = argv[0].asI32();
uvwasi_fdstat_t* fdstat = reinterpret_cast<uvwasi_fdstat_t*>(get_memory_pointer(instance, argv[1], sizeof(uvwasi_fdstat_t)));

result[0] = Value(uvwasi_fd_fdstat_get(WASI::g_uvwasi, fd, fdstat));
}

void WASI::fd_seek(ExecutionState& state, Value* argv, Value* result, Instance* instance)
Expand Down Expand Up @@ -218,7 +291,7 @@ void WASI::environ_get(ExecutionState& state, Value* argv, Value* result, Instan
return;
}

TemporaryData pointers(count * sizeof(void*));
TemporaryData<void*, 8> pointers(count);

if (pointers.data() == nullptr) {
result[0] = Value(WasiErrNo::inval);
Expand Down
6 changes: 6 additions & 0 deletions src/wasi/WASI.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class WASI {
// https://github.com/WebAssembly/WASI/blob/main/legacy/preview1/docs.md

#define FOR_EACH_WASI_FUNC(F) \
F(args_get, I32I32_RI32) \
F(args_sizes_get, I32I32_RI32) \
F(proc_exit, I32R) \
F(proc_raise, I32_RI32) \
F(clock_res_get, I32I32_RI32) \
Expand All @@ -44,6 +46,7 @@ class WASI {
F(fd_write, I32I32I32I32_RI32) \
F(fd_read, I32I32I32I32_RI32) \
F(fd_close, I32_RI32) \
F(fd_fdstat_get, I32I32_RI32) \
F(fd_seek, I32I64I32I32_RI32) \
F(path_open, I32I32I32I32I32I64I64I32I32_RI32) \
F(environ_get, I32I32_RI32) \
Expand Down Expand Up @@ -152,13 +155,16 @@ class WASI {

private:
// wasi functions
static void args_get(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void args_sizes_get(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void proc_exit(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void proc_raise(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void clock_res_get(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void clock_time_get(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void fd_write(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void fd_read(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void fd_close(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void fd_fdstat_get(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void fd_seek(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void path_open(ExecutionState& state, Value* argv, Value* result, Instance* instance);
static void environ_get(ExecutionState& state, Value* argv, Value* result, Instance* instance);
Expand Down
75 changes: 75 additions & 0 deletions test/wasi/args.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
(module
(import "wasi_snapshot_preview1" "fd_write" (func $wasi_fd_write (param i32 i32 i32 i32) (result i32)))
(import "wasi_snapshot_preview1" "args_get" (func $wasi_args_get (param i32 i32) (result i32)))
(import "wasi_snapshot_preview1" "args_sizes_get" (func $wasi_args_sizes_get (param i32 i32) (result i32)))
(memory 1)

(export "memory" (memory 0))
(export "print_args" (func $print_args))
(data (i32.const 100) "500" )

(func $print_args
(local $i i32)
i32.const 0 ;; args count
i32.const 12 ;; args overall size in characters
call $wasi_args_sizes_get
drop

i32.const 500 ;; argp
i32.const 600 ;; argp[0]
call $wasi_args_get
drop

;; Memory[100] = 600, start of output string.
i32.const 100
i32.const 600
i32.store

;; Memory[104] = size of output string.
i32.const 104
i32.const 12
i32.load
i32.store

;; Replace '\0' with '\n' for readable printing.
i32.const 0
local.set $i
(loop $loop
i32.const 600
local.get $i
i32.add
i32.load8_u

i32.eqz
(if
(then
i32.const 600
local.get $i
i32.add
i32.const 10
i32.store8
)
)

local.get $i
i32.const 1
i32.add
local.tee $i

i32.const 12
i32.load
i32.lt_u
br_if $loop
)

(call $wasi_fd_write
(i32.const 1) ;;file descriptor
(i32.const 100) ;;offset of str offset
(i32.const 1) ;;iovec length
(i32.const 200) ;;result offset
)
drop
)
)

(assert_return (invoke "print_args"))
Loading
Loading