Skip to content

Commit

Permalink
feat: wasm backend (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliassjogreen authored May 7, 2022
1 parent b1a0f66 commit 8276043
Show file tree
Hide file tree
Showing 30 changed files with 333 additions and 305 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
uses: actions/checkout@v2

- name: Setup latest deno version
uses: denolib/setup-deno@v2
uses: denoland/setup-deno@main
with:
deno-version: v1.x

Expand All @@ -28,7 +28,7 @@ jobs:
uses: actions/checkout@v2

- name: Setup latest deno version
uses: denolib/setup-deno@v2
uses: denoland/setup-deno@main
with:
deno-version: v1.x

Expand Down
10 changes: 9 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[workspace]
members = [
"backend/rust",
"backend/wasm"
# TODO: "backend/native"
]

[profile.release]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
[![License](https://img.shields.io/github/license/denosaurs/neo)](https://github.com/denosaurs/neo/blob/main/LICENSE)

`neo` is a module for working with matrices and other linear algebra,
accelerated using WebGPU and wasm.
accelerated using WebGPU.

## Maintainers

Expand All @@ -21,4 +21,4 @@ Pull request, issues and feedback are very welcome. Code style is formatted with

### Licence

Copyright 2021-2022, the denosaurs team. All rights reserved. MIT license.
Copyright 2021, the denosaurs team. All rights reserved. MIT license.
12 changes: 12 additions & 0 deletions backend/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "neo_rust"
license = "MIT"
version = "0.1.0"
authors = ["Elias Sjögreen"]
edition = "2021"

[lib]
path = "lib.rs"

[dependencies]
wee_alloc = "0.4.5"
50 changes: 50 additions & 0 deletions backend/rust/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
macro_rules! binary_operator {
($identifier:ident, $type:ty, $closure:expr) => {
#[no_mangle]
pub unsafe extern "C" fn $identifier(
a_data: *const $type,
b_data: *const $type,
c_data: *mut $type,
len: usize,
) {
let a_data = &core::slice::from_raw_parts(a_data, len)[..len];
let b_data = &core::slice::from_raw_parts(b_data, len)[..len];
let c_data = &mut core::slice::from_raw_parts_mut(c_data, len)[..len];

for i in 0..len {
c_data[i] = $closure(a_data[i], b_data[i]);
}
}
};
}

binary_operator!(add_f32, f32, |a, b| a + b);
binary_operator!(add_u32, u32, |a, b| core::intrinsics::unchecked_add(a, b));
binary_operator!(add_i32, i32, |a, b| core::intrinsics::unchecked_add(a, b));

binary_operator!(sub_f32, f32, |a, b| a - b);
binary_operator!(sub_u32, u32, |a, b| core::intrinsics::unchecked_sub(a, b));
binary_operator!(sub_i32, i32, |a, b| core::intrinsics::unchecked_sub(a, b));

binary_operator!(mul_f32, f32, |a, b| a * b);
binary_operator!(mul_u32, u32, |a, b| core::intrinsics::unchecked_mul(a, b));
binary_operator!(mul_i32, i32, |a, b| core::intrinsics::unchecked_mul(a, b));

binary_operator!(div_f32, f32, |a, b| a / b);
binary_operator!(div_u32, u32, |a, b| core::intrinsics::unchecked_div(a, b));
binary_operator!(div_i32, i32, |a, b| core::intrinsics::unchecked_div(a, b));

binary_operator!(mod_f32, f32, |a, b| a % b);
binary_operator!(mod_u32, u32, |a, b| core::intrinsics::unchecked_rem(a, b));
binary_operator!(mod_i32, i32, |a, b| core::intrinsics::unchecked_rem(a, b));

binary_operator!(min_f32, f32, |a: f32, b: f32| a.min(b));
binary_operator!(min_u32, u32, |a: u32, b: u32| a.min(b));
binary_operator!(min_i32, i32, |a: i32, b: i32| a.min(b));

binary_operator!(max_f32, f32, |a: f32, b: f32| a.max(b));
binary_operator!(max_u32, u32, |a: u32, b: u32| a.max(b));
binary_operator!(max_i32, i32, |a: i32, b: i32| a.max(b));

binary_operator!(prelu_f32, f32, |a, b| if a < 0.0 { a * b } else { a });
binary_operator!(prelu_i32, i32, |a, b| if a < 0 { a * b } else { a });
5 changes: 5 additions & 0 deletions backend/rust/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#![no_std]
#![feature(core_intrinsics)]

pub mod binary;
pub mod unary;
18 changes: 9 additions & 9 deletions backend/wasm/matmul.rs → backend/rust/matmul.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
macro_rules! matmul_impl {
($identifier:ident, $type:ty) => {
#[no_mangle]
pub extern "C" fn $identifier(
m: usize,
n: usize,
k: usize,
pub unsafe extern "C" fn $identifier(
a_data: *const $type,
b_data: *const $type,
c_data: *mut $type,
m: usize,
n: usize,
k: usize,
) {
let a_len = m * k;
let b_len = k * n;
let c_len = m * n;
let a_data =
&unsafe { core::slice::from_raw_parts(a_data, a_len) }[..a_len];
&core::slice::from_raw_parts(a_data, a_len)[..a_len];
let b_data =
&unsafe { core::slice::from_raw_parts(b_data, b_len) }[..b_len];
&core::slice::from_raw_parts(b_data, b_len)[..b_len];
let c_data =
&mut unsafe { core::slice::from_raw_parts_mut(c_data, c_len) }[..c_len];
&mut core::slice::from_raw_parts_mut(c_data, c_len)[..c_len];

for l_index in 0..m {
for m_index in 0..k {
Expand All @@ -27,10 +27,10 @@ macro_rules! matmul_impl {
l_index * k + m_index,
m_index * n + n_index,
);
unsafe {

*c_data.get_unchecked_mut(i) +=
a_data.get_unchecked(j) * b_data.get_unchecked(k)
};
;
}
}
}
Expand Down
25 changes: 25 additions & 0 deletions backend/rust/transpose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
macro_rules! transpose_operator {
($identifier:ident, $type:ty) => {
#[no_mangle]
pub unsafe extern "C" fn $identifier(
a_data: *const $type,
b_data: *mut $type,
w: usize,
h: usize,
) {
let len = w * h;
let a_data = &core::slice::from_raw_parts(a_data, len)[..len];
let b_data = &mut core::slice::from_raw_parts_mut(b_data, len)[..len];

for x in 0..w {
for y in 0..h {
b_data[y + x * h] = a_data[x + y * w];
}
}
}
};
}

transpose_operator!(transpose_f32, f32);
transpose_operator!(transpose_u32, u32);
transpose_operator!(transpose_i32, i32);
93 changes: 93 additions & 0 deletions backend/rust/unary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
macro_rules! unary_operator {
($identifier:ident, $type:ty, $closure:expr) => {
#[no_mangle]
pub unsafe extern "C" fn $identifier(
a_data: *const $type,
b_data: *mut $type,
len: usize,
) {
let a_data = &core::slice::from_raw_parts(a_data, len)[..len];
let b_data = &mut core::slice::from_raw_parts_mut(b_data, len)[..len];

for i in 0..len {
b_data[i] = $closure(a_data[i]);
}
}
};
}

unary_operator!(abs_f32, f32, |a: f32| if a < 0.0 { -a } else { a });
unary_operator!(abs_i32, i32, |a: i32| if a < 0 { -a } else { a });

unary_operator!(linear_f32, f32, |a| a);
unary_operator!(linear_u32, u32, |a| a);
unary_operator!(linear_i32, i32, |a| a);

unary_operator!(neg_f32, f32, |a: f32| -a);
unary_operator!(neg_i32, i32, |a: i32| -a);

unary_operator!(inc_f32, f32, |a| a + 1.0);
unary_operator!(inc_u32, u32, |a| a + 1);
unary_operator!(inc_i32, i32, |a| a + 1);

unary_operator!(dec_f32, f32, |a| a - 1.0);
unary_operator!(dec_u32, u32, |a| a - 1);
unary_operator!(dec_i32, i32, |a| a - 1);

unary_operator!(relu_f32, f32, |a: f32| a.max(0.0));
unary_operator!(relu_i32, i32, |a: i32| a.max(0));

unary_operator!(relu6_f32, f32, |a: f32| a.clamp(0.0, 6.0));
unary_operator!(relu6_i32, i32, |a: i32| a.clamp(0, 6));

unary_operator!(ceil_f32, f32, |a: f32| core::intrinsics::ceilf32(a));
unary_operator!(floor_f32, f32, |a: f32| core::intrinsics::floorf32(a));
unary_operator!(round_f32, f32, |a: f32| core::intrinsics::roundf32(a));
unary_operator!(sqrt_f32, f32, |a: f32| core::intrinsics::sqrtf32(a));
unary_operator!(rsqrt_f32, f32, |a: f32| 1.0 / core::intrinsics::sqrtf32(a));
unary_operator!(sigmoid_f32, f32, |a: f32| 1.0
/ (1.0 + core::intrinsics::expf32(-1.0 * a)));

unary_operator!(square_f32, f32, |a| a * a);
unary_operator!(square_u32, u32, |a| a * a);
unary_operator!(square_i32, i32, |a| a * a);

unary_operator!(cos_f32, f32, |a: f32| core::intrinsics::cosf32(a));
unary_operator!(cosh_f32, f32, |a: f32| {
let e2x = core::intrinsics::expf32(-a);
(e2x + 1.0 / e2x) / 2.0
});

unary_operator!(sin_f32, f32, |a: f32| core::intrinsics::sinf32(a));
unary_operator!(sinh_f32, f32, |a: f32| {
let e2x = core::intrinsics::expf32(a);
(e2x - 1.0 / e2x) / 2.0
});

unary_operator!(tan_f32, f32, |a: f32| core::intrinsics::sinf32(a)
/ core::intrinsics::cosf32(a));
unary_operator!(tanh_f32, f32, |a: f32| {
let e2x = core::intrinsics::expf32(-2.0 * if a < 0.0 { -a } else { a });
(if a.is_nan() {
f32::NAN
} else if a.is_sign_negative() {
-1.0
} else {
1.0
}) * (1.0 - e2x)
/ (1.0 + e2x)
});

unary_operator!(exp_f32, f32, |a: f32| core::intrinsics::expf32(a));

unary_operator!(elu_f32, f32, |a: f32| if a >= 0.0 {
a
} else {
core::intrinsics::expf32(a) - 1.0
});

unary_operator!(log_f32, f32, |a: f32| if a < 0.0 {
f32::INFINITY
} else {
core::intrinsics::logf32(a)
});
3 changes: 2 additions & 1 deletion backend/wasm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "neo"
name = "neo_wasm"
license = "MIT"
version = "0.1.0"
authors = ["Elias Sjögreen"]
Expand All @@ -10,4 +10,5 @@ crate-type = ["cdylib"]
path = "lib.rs"

[dependencies]
neo_rust = { path = "../rust" }
wee_alloc = "0.4.5"
45 changes: 24 additions & 21 deletions backend/wasm/backend.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// import { wasm } from "../../util.ts";
import { Backend, BackendRequest, DataType } from "../types.ts";
import { WasmData } from "./data.ts";

Expand All @@ -7,7 +6,7 @@ const decoder = new TextDecoder();
export interface WasmBackendRequest<T extends DataType = DataType>
extends BackendRequest<T> {
func: string;
args: (number | bigint)[];
args: number[];
data: WasmData<T>[];
}

Expand All @@ -18,35 +17,30 @@ export class WasmBackend implements Backend {

instance!: WebAssembly.Instance;
memory!: WebAssembly.Memory;

alloc!: (size: number) => number;
dealloc!: (
pointer: number,
size: number,
) => void;
dealloc!: (ptr: number, size: number) => void;

async initialize(): Promise<void> {
if (this.initalized) {
return;
}

const { source } = await import("./wasm.js");
const { instance } = await WebAssembly.instantiate(source, {
const { instance: { exports } } = await WebAssembly.instantiate(source, {
env: {
panic: (pointer: number, len: number) => {
panic: (ptr: number, len: number) => {
const msg = decoder.decode(
new Uint8Array(memory.buffer, pointer, len),
new Uint8Array(this.memory.buffer, ptr, len),
);
throw new Error(msg);
},
},
});
const memory = instance.exports.memory as WebAssembly.Memory;

this.instance = instance;
this.memory = memory;
this.alloc = instance.exports.alloc as (size: number) => number;
this.dealloc = instance.exports.dealloc as (
pointer: number,
this.memory = exports.memory as WebAssembly.Memory;
this.alloc = exports.alloc as (size: number) => number;
this.dealloc = exports.dealloc as (
ptr: number,
size: number,
) => void;

Expand All @@ -55,10 +49,19 @@ export class WasmBackend implements Backend {

// deno-lint-ignore require-await
async execute(request: WasmBackendRequest): Promise<void> {
// deno-lint-ignore no-explicit-any
(this.instance.exports[request.func] as (...args: any[]) => any)(
...request.args,
...request.data.map((data) => data.pointer),
);
if (!this.initalized) {
throw new Error("WasmBackend is not initialized");
}

const func = this.instance
.exports[request.func] as (((...args: unknown[]) => unknown) | undefined);

if (func === undefined) {
throw new Error(`Could not find wasm function ${request.func}`);
}

const args = request.data.map((data) => data.ptr).concat(request.args);

func(...args);
}
}
Loading

0 comments on commit 8276043

Please sign in to comment.