Skip to content

Commit

Permalink
feat(core): Core interface for backend-agnostic operators (#14)
Browse files Browse the repository at this point in the history
* An abomination

This pushes the TypeScript type system to its limits, but it works.

Co-authored-by: Dean Srebnik <[email protected]>

* feat: WGSL pre-processor and redo-types

Co-authored-by: Dean Srebnik <[email protected]>

* rework core, again!

* core, `operatorOnBackend` and typings

* fix wasm build

* rough matmul benchmark

Co-authored-by: Dean Srebnik <[email protected]>

* fix operators for new `BackendOperator` types

* core operator with both wasm and webgpu!

* update README and logo

* selu

* u64 support in wgsl

* fmt

* binary and unary core operators

* update logo

* feat: transpose core operator

Co-authored-by: Dean Srebnik <[email protected]>
  • Loading branch information
eliassjogreen and load1n9 authored Aug 11, 2022
1 parent 8276043 commit 1a50eae
Show file tree
Hide file tree
Showing 57 changed files with 5,893 additions and 25,822 deletions.
30 changes: 21 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
# neo

[![Tags](https://img.shields.io/github/release/denosaurs/neo)](https://github.com/denosaurs/neo/releases)
[![CI Status](https://img.shields.io/github/workflow/status/denosaurs/neo/check)](https://github.com/denosaurs/neo/actions)
[![Dependencies](https://img.shields.io/github/workflow/status/denosaurs/neo/depsbot?label=dependencies)](https://github.com/denosaurs/depsbot)
[![License](https://img.shields.io/github/license/denosaurs/neo)](https://github.com/denosaurs/neo/blob/main/LICENSE)

`neo` is a module for working with matrices and other linear algebra,
accelerated using WebGPU.
<p align="center">
<img src="./assets/neo.svg" width="80rem" />
<br/>
<h1 align="center">neo</h1>
</p>

<p align="center">
<a href="https://github.com/denosaurs/neo/releases">
<img alt="GitHub release (latest by date including pre-releases)" src="https://img.shields.io/github/v/release/denosaurs/neo?include_prereleases" />
</a>
<a href="https://github.com/denosaurs/neo/actions">
<img alt="GitHub Workflow Status" src="https://img.shields.io/github/workflow/status/denosaurs/neo/check" />
</a>
<a href="https://github.com/denosaurs/neo/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/denosaurs/neo" />
</a>
</p>

`neo` is a module for working with matrices, ndarrays, tensors and linear
algebra in Deno. Accelerated using WebGPU and WASM, it runs anywhere a browser
runs.

## Maintainers

Expand Down
28,883 changes: 3,524 additions & 25,359 deletions assets/neo.ai

Large diffs are not rendered by default.

Binary file added assets/neo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
494 changes: 494 additions & 0 deletions assets/neo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
119 changes: 119 additions & 0 deletions backend/core/core.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import {
Backend,
BackendConstructor,
BackendOperator,
BackendType,
} from "../types/backend.ts";
import { Data } from "../types/data.ts";
import { getDataConstructorFor } from "../util/data.ts";
import { WasmBackend } from "../wasm/backend.ts";
import { WebGPUBackend } from "../webgpu/backend.ts";

/** Options for {@link Core.initialize}. */
export interface CoreBackendOptions {
  // Backends to use, each given as a ready backend instance, a backend type
  // name ("wasm" | "webgpu"), or a backend constructor to be instantiated.
  backends?: (
    | Backend
    | BackendType
    | BackendConstructor<WasmBackend | WebGPUBackend>
  )[];
}

/**
 * Entry point for backend-agnostic operator execution: resolves, filters and
 * initializes the set of compute backends (WASM and/or WebGPU).
 */
export class Core {
  // NOTE(review): "initalized" is a typo for "initialized", but it is part of
  // the public surface (and mirrored by `Backend.initalized` below), so it is
  // kept for compatibility.
  initalized = false;
  supported = true;

  /** Successfully resolved and supported backends, keyed by their type. */
  backends: Map<BackendType, Backend> = new Map();

  /**
   * Resolves the requested backends (by type name, instance or constructor),
   * drops unsupported ones, and initializes the remainder in parallel.
   * Defaults to trying both the "wasm" and "webgpu" backends.
   *
   * @throws Error if an unrecognized backend type name is given.
   */
  async initialize(options?: CoreBackendOptions): Promise<void> {
    const backends = options?.backends ?? ["wasm", "webgpu"];

    for (let backend of backends) {
      if (typeof backend === "string") {
        switch (backend) {
          case "wasm":
            backend = new WasmBackend();
            break;
          case "webgpu":
            backend = new WebGPUBackend();
            break;
          default:
            // Previously an unknown name fell through here and crashed later
            // on the `"type" in backend` check — fail fast with a clear error.
            throw new Error(`Unknown backend type "${backend}"`);
        }
      }

      // Anything without a `type` property is a constructor, not an instance.
      if (!("type" in backend)) {
        backend = new backend();
      }

      this.backends.set(backend.type, backend);
    }

    // Remove unsupported backends and initialize the rest concurrently.
    const initializing = [];
    for (const [type, backend] of this.backends) {
      if (!backend.supported) {
        this.backends.delete(type);
        continue;
      }

      if (!backend.initalized) {
        initializing.push(backend.initialize());
      }
    }
    await Promise.all(initializing);

    this.initalized = true;
  }
}

/**
 * Executes a backend-specific operator on `backend`, transparently bridging
 * any input data that lives on a different backend.
 *
 * For every entry in `data` whose backend type differs from `backend.type`, a
 * temporary copy is allocated on `backend` and filled with the entry's
 * contents. After the operator has run, each temporary is copied back into
 * its original entry (the operator may have written to any of its inputs)
 * and then disposed.
 *
 * @param backend  The backend to run the operator on.
 * @param operator The backend-specific operator implementation to invoke.
 * @param data     Operand buffers; each may live on any backend.
 * @param args     Extra operator arguments (e.g. matrix dimensions).
 * @returns Whatever the operator itself returns.
 */
export async function operatorOnBackend<
  B extends Backend,
  D extends Data[],
  A extends Record<string, unknown> | undefined,
  R = void,
>(
  backend: B,
  operator: BackendOperator<B, D, A, R>,
  data: Data<D extends Data<infer T>[] ? T : never>[],
  args: A,
): Promise<R> {
  const DataConstructor = getDataConstructorFor(backend.type);

  // TODO: Should and could definitely be optimized with a persistent data pool and "meta" allocator
  // Entries that needed conversion are stored as [original, temporary] pairs
  // so they can be copied back and disposed once the operator has run.
  const convertedData: (Data | [Data, Data])[] = [];
  const conversions: Promise<void>[] = [];
  for (let index = 0; index < data.length; index++) {
    const entry = data[index];

    if (entry.backend.type !== backend.type) {
      // Conversion requires a readback plus an upload; run them all
      // concurrently and join on Promise.all below.
      conversions.push((async () => {
        const content = await entry.get();
        const temporaryData = await DataConstructor.from(
          backend,
          content,
          entry.type,
        );
        convertedData[index] = [entry, temporaryData];
      })());
    } else {
      convertedData[index] = entry;
    }
  }
  await Promise.all(conversions);

  // The operator only ever sees data living on its own backend.
  const operatorData = convertedData.map((data) =>
    Array.isArray(data) ? data[1] : data
  ) as D;
  const result = await operator(backend, operatorData, args);

  // Deconvert the converted data to its original form in case all of the data does not share the same backend type
  const deconversions: Promise<void>[] = [];
  for (const dataPair of convertedData) {
    if (Array.isArray(dataPair)) {
      deconversions.push((async () => {
        const [data, temporaryData] = dataPair;
        await data.set(await temporaryData.get());
        temporaryData.dispose();
      })());
    }
  }
  await Promise.all(deconversions);

  return result;
}
51 changes: 51 additions & 0 deletions backend/core/operators/binary.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { Data, DataType } from "../../types/data.ts";
import { WasmBackend } from "../../wasm/backend.ts";
import { Core, operatorOnBackend } from "../core.ts";
import * as wasm from "../../wasm/operators/binary.ts";
import * as webgpu from "../../webgpu/operators/binary.ts";
import { WebGPUBackend } from "../../webgpu/backend.ts";
import { BackendOperator } from "../../types/backend.ts";
import { WebGPUData } from "../../webgpu/data.ts";
import { WasmData } from "../../wasm/data.ts";

/**
 * Builds a core element-wise binary operator that dispatches to either a WASM
 * or a WebGPU implementation based on operand size: inputs at or below
 * 80 * 80 elements run on WASM, anything larger runs on WebGPU.
 *
 * @param wasmOperator   WASM implementation of the operator.
 * @param webgpuOperator WebGPU implementation of the operator.
 * @returns An async core operator taking (core, [a, b, out]).
 */
export function binary<T extends DataType>(
  wasmOperator: BackendOperator<
    WasmBackend,
    [WasmData<T>, WasmData<T>, WasmData<T>],
    undefined,
    void
  >,
  webgpuOperator: BackendOperator<
    WebGPUBackend,
    [WebGPUData<T>, WebGPUData<T>, WebGPUData<T>],
    undefined,
    Promise<void>
  >,
) {
  return async function (core: Core, data: [Data<T>, Data<T>, Data<T>]) {
    // Size of the largest operand decides which backend to use.
    const largest = data.reduce(
      (result, { length }) => length > result ? length : result,
      0,
    );

    if (largest > 80 * 80) {
      return await operatorOnBackend(
        core.backends.get("webgpu")! as WebGPUBackend,
        webgpuOperator,
        data,
        undefined,
      );
    }

    return await operatorOnBackend(
      core.backends.get("wasm")! as WasmBackend,
      wasmOperator,
      data,
      undefined,
    );
  };
}

// Element-wise binary core operators. Each pairs a WASM and a WebGPU
// implementation; `binary` dispatches between them based on operand size.
export const add = binary(wasm.add, webgpu.add);
export const sub = binary(wasm.sub, webgpu.sub);
export const mul = binary(wasm.mul, webgpu.mul);
export const div = binary(wasm.div, webgpu.div);
export const mod = binary(wasm.mod, webgpu.mod);
export const min = binary(wasm.min, webgpu.min);
export const max = binary(wasm.max, webgpu.max);
export const prelu = binary(wasm.prelu, webgpu.prelu);
28 changes: 28 additions & 0 deletions backend/core/operators/matmul.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { Data } from "../../types/data.ts";
import { WasmBackend } from "../../wasm/backend.ts";
import { Core, operatorOnBackend } from "../core.ts";
import { matmul as wasmMatmul } from "../../wasm/operators/matmul.ts";
import { matmul as webgpuMatmul } from "../../webgpu/operators/matmul.ts";
import { WebGPUBackend } from "../../webgpu/backend.ts";

/**
 * Core matrix-multiplication operator with dimension arguments `m`, `n`, `k`.
 * Operands at or below 72 * 72 elements run on the WASM backend; larger ones
 * are dispatched to WebGPU.
 */
export async function matmul<T extends "f32" | "u32" | "i32">(
  core: Core,
  data: [Data<T>, Data<T>, Data<T>],
  args: { m: number; n: number; k: number },
) {
  // Size of the largest operand decides which backend to use.
  const largest = data.reduce(
    (result, { length }) => length > result ? length : result,
    0,
  );

  if (largest > 72 * 72) {
    return await operatorOnBackend(
      core.backends.get("webgpu")! as WebGPUBackend,
      webgpuMatmul,
      data,
      args,
    );
  }

  return await operatorOnBackend(
    core.backends.get("wasm")! as WasmBackend,
    wasmMatmul,
    data,
    args,
  );
}
28 changes: 28 additions & 0 deletions backend/core/operators/transpose.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { Data } from "../../types/data.ts";
import { WasmBackend } from "../../wasm/backend.ts";
import { Core, operatorOnBackend } from "../core.ts";
import { transpose as wasmTranspose } from "../../wasm/operators/transpose.ts";
import { transpose as webgpuTranspose } from "../../webgpu/operators/transpose.ts";
import { WebGPUBackend } from "../../webgpu/backend.ts";

/**
 * Core transpose operator with dimension arguments `w` and `h`. Operands at
 * or below 90 * 90 elements run on the WASM backend; larger ones are
 * dispatched to WebGPU.
 */
export async function transpose<T extends "f32" | "u32" | "i32">(
  core: Core,
  data: [Data<T>, Data<T>],
  args: { w: number; h: number },
) {
  // Size of the largest operand decides which backend to use.
  const largest = data.reduce(
    (result, { length }) => length > result ? length : result,
    0,
  );

  if (largest > 90 * 90) {
    return await operatorOnBackend(
      core.backends.get("webgpu")! as WebGPUBackend,
      webgpuTranspose,
      data,
      args,
    );
  }

  return await operatorOnBackend(
    core.backends.get("wasm")! as WasmBackend,
    wasmTranspose,
    data,
    args,
  );
}
67 changes: 67 additions & 0 deletions backend/core/operators/unary.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { Data, DataType } from "../../types/data.ts";
import { WasmBackend } from "../../wasm/backend.ts";
import { Core, operatorOnBackend } from "../core.ts";
import * as wasm from "../../wasm/operators/unary.ts";
import * as webgpu from "../../webgpu/operators/unary.ts";
import { WebGPUBackend } from "../../webgpu/backend.ts";
import { BackendOperator } from "../../types/backend.ts";
import { WebGPUData } from "../../webgpu/data.ts";
import { WasmData } from "../../wasm/data.ts";

/**
 * Builds a core element-wise unary operator that dispatches to either a WASM
 * or a WebGPU implementation based on operand size: inputs at or below
 * 90 * 90 elements run on WASM, anything larger runs on WebGPU.
 *
 * @param wasmOperator   WASM implementation of the operator.
 * @param webgpuOperator WebGPU implementation of the operator.
 * @returns An async core operator taking (core, [input, output]).
 */
export function unary<T extends DataType>(
  wasmOperator: BackendOperator<
    WasmBackend,
    [WasmData<T>, WasmData<T>],
    undefined,
    void
  >,
  webgpuOperator: BackendOperator<
    WebGPUBackend,
    [WebGPUData<T>, WebGPUData<T>],
    undefined,
    Promise<void>
  >,
) {
  return async function (core: Core, data: [Data<T>, Data<T>]) {
    // Size of the largest operand decides which backend to use.
    const largest = data.reduce(
      (result, { length }) => length > result ? length : result,
      0,
    );

    if (largest > 90 * 90) {
      return await operatorOnBackend(
        core.backends.get("webgpu")! as WebGPUBackend,
        webgpuOperator,
        data,
        undefined,
      );
    }

    return await operatorOnBackend(
      core.backends.get("wasm")! as WasmBackend,
      wasmOperator,
      data,
      undefined,
    );
  };
}

// Element-wise unary core operators. Each pairs a WASM and a WebGPU
// implementation; `unary` dispatches between them based on operand size. The
// explicit type arguments narrow each operator to the data types it supports
// (e.g. `neg` is not provided for "u32").
export const abs = unary(wasm.abs, webgpu.abs);
export const linear = unary<"f32" | "u32" | "i32">(wasm.linear, webgpu.linear);
export const neg = unary<"f32" | "i32">(wasm.neg, webgpu.neg);
export const inc = unary<"f32" | "u32" | "i32">(wasm.inc, webgpu.inc);
export const dec = unary<"f32" | "u32" | "i32">(wasm.dec, webgpu.dec);
export const relu = unary<"f32" | "i32">(wasm.relu, webgpu.relu);
export const relu6 = unary<"f32" | "i32">(wasm.relu6, webgpu.relu6);
export const ceil = unary<"f32">(wasm.ceil, webgpu.ceil);
export const floor = unary<"f32">(wasm.floor, webgpu.floor);
export const round = unary<"f32">(wasm.round, webgpu.round);
export const sqrt = unary<"f32">(wasm.sqrt, webgpu.sqrt);
export const rsqrt = unary<"f32">(wasm.rsqrt, webgpu.rsqrt);
export const selu = unary<"f32">(wasm.selu, webgpu.selu);
export const sigmoid = unary<"f32">(wasm.sigmoid, webgpu.sigmoid);
export const square = unary<"f32" | "u32" | "i32">(wasm.square, webgpu.square);
export const cos = unary<"f32">(wasm.cos, webgpu.cos);
export const cosh = unary<"f32">(wasm.cosh, webgpu.cosh);
export const sin = unary<"f32">(wasm.sin, webgpu.sin);
export const sinh = unary<"f32">(wasm.sinh, webgpu.sinh);
export const tan = unary<"f32">(wasm.tan, webgpu.tan);
export const tanh = unary<"f32">(wasm.tanh, webgpu.tanh);
export const exp = unary<"f32">(wasm.exp, webgpu.exp);
export const elu = unary<"f32">(wasm.elu, webgpu.elu);
export const log = unary<"f32">(wasm.log, webgpu.log);
2 changes: 1 addition & 1 deletion backend/rust/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ macro_rules! binary_operator {
($identifier:ident, $type:ty, $closure:expr) => {
#[no_mangle]
pub unsafe extern "C" fn $identifier(
len: usize,
a_data: *const $type,
b_data: *const $type,
c_data: *mut $type,
len: usize,
) {
let a_data = &core::slice::from_raw_parts(a_data, len)[..len];
let b_data = &core::slice::from_raw_parts(b_data, len)[..len];
Expand Down
2 changes: 2 additions & 0 deletions backend/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
#![feature(core_intrinsics)]

pub mod binary;
pub mod matmul;
pub mod transpose;
pub mod unary;
22 changes: 9 additions & 13 deletions backend/rust/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,19 @@ macro_rules! matmul_impl {
($identifier:ident, $type:ty) => {
#[no_mangle]
pub unsafe extern "C" fn $identifier(
a_data: *const $type,
b_data: *const $type,
c_data: *mut $type,
m: usize,
n: usize,
k: usize,
a_data: *const $type,
b_data: *const $type,
c_data: *mut $type,
) {
let a_len = m * k;
let b_len = k * n;
let c_len = m * n;
let a_data =
&core::slice::from_raw_parts(a_data, a_len)[..a_len];
let b_data =
&core::slice::from_raw_parts(b_data, b_len)[..b_len];
let c_data =
&mut core::slice::from_raw_parts_mut(c_data, c_len)[..c_len];
let a_data = &core::slice::from_raw_parts(a_data, a_len)[..a_len];
let b_data = &core::slice::from_raw_parts(b_data, b_len)[..b_len];
let c_data = &mut core::slice::from_raw_parts_mut(c_data, c_len)[..c_len];

for l_index in 0..m {
for m_index in 0..k {
Expand All @@ -27,10 +24,9 @@ macro_rules! matmul_impl {
l_index * k + m_index,
m_index * n + n_index,
);

*c_data.get_unchecked_mut(i) +=
a_data.get_unchecked(j) * b_data.get_unchecked(k)
;

*c_data.get_unchecked_mut(i) +=
a_data.get_unchecked(j) * b_data.get_unchecked(k);
}
}
}
Expand Down
Loading

0 comments on commit 1a50eae

Please sign in to comment.