-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(core):
Core
interface for backend-agnostic operators (#14)
* An abomination Scares the shit out of the typescript gods and me but at least it works Co-authored-by: Dean Srebnik <[email protected]> * feat: WGSL pre-processor and redo-types Co-authored-by: Dean Srebnik <[email protected]> * rework core, again! * core, `operatorOnBackend` and typings * fix wasm build * shitty matmul bench Co-authored-by: Dean Srebnik <[email protected]> * fix operators for new `BackendOperator` types * core operator with both wasm and webgpu! * update README and logo * selu * u64 support in wgsl * fmt * binary and unary core operators * update logo * feat: transpose core operator Co-authored-by: Dean Srebnik <[email protected]>
- Loading branch information
1 parent
8276043
commit 1a50eae
Showing
57 changed files
with
5,893 additions
and
25,822 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import { | ||
Backend, | ||
BackendConstructor, | ||
BackendOperator, | ||
BackendType, | ||
} from "../types/backend.ts"; | ||
import { Data } from "../types/data.ts"; | ||
import { getDataConstructorFor } from "../util/data.ts"; | ||
import { WasmBackend } from "../wasm/backend.ts"; | ||
import { WebGPUBackend } from "../webgpu/backend.ts"; | ||
|
||
export interface CoreBackendOptions { | ||
backends?: ( | ||
| Backend | ||
| BackendType | ||
| BackendConstructor<WasmBackend | WebGPUBackend> | ||
)[]; | ||
} | ||
|
||
export class Core { | ||
initalized = false; | ||
supported = true; | ||
|
||
backends: Map<BackendType, Backend> = new Map(); | ||
|
||
async initialize(options?: CoreBackendOptions): Promise<void> { | ||
const backends = options?.backends ?? ["wasm", "webgpu"]; | ||
|
||
for (let backend of backends) { | ||
if (typeof backend === "string") { | ||
switch (backend) { | ||
case "wasm": | ||
backend = new WasmBackend(); | ||
break; | ||
case "webgpu": | ||
backend = new WebGPUBackend(); | ||
break; | ||
} | ||
} | ||
|
||
if (!("type" in backend)) { | ||
backend = new backend(); | ||
} | ||
|
||
this.backends.set(backend.type, backend); | ||
} | ||
|
||
const initializing = []; | ||
for (const [type, backend] of this.backends) { | ||
if (!backend.supported) { | ||
this.backends.delete(type); | ||
continue; | ||
} | ||
|
||
if (!backend.initalized) { | ||
initializing.push(backend.initialize()); | ||
} | ||
} | ||
await Promise.all(initializing); | ||
|
||
this.initalized = true; | ||
} | ||
} | ||
|
||
export async function operatorOnBackend< | ||
B extends Backend, | ||
D extends Data[], | ||
A extends Record<string, unknown> | undefined, | ||
R = void, | ||
>( | ||
backend: B, | ||
operator: BackendOperator<B, D, A, R>, | ||
data: Data<D extends Data<infer T>[] ? T : never>[], | ||
args: A, | ||
): Promise<R> { | ||
const DataConstructor = getDataConstructorFor(backend.type); | ||
|
||
// TODO: Should and could definitely be optimized with a persistant data pool and "meta" allocator | ||
const convertedData: (Data | [Data, Data])[] = []; | ||
const conversions: Promise<void>[] = []; | ||
for (let index = 0; index < data.length; index++) { | ||
const entry = data[index]; | ||
|
||
if (entry.backend.type !== backend.type) { | ||
conversions.push((async () => { | ||
const content = await entry.get(); | ||
const temporaryData = await DataConstructor.from( | ||
backend, | ||
content, | ||
entry.type, | ||
); | ||
convertedData[index] = [entry, temporaryData]; | ||
})()); | ||
} else { | ||
convertedData[index] = entry; | ||
} | ||
} | ||
await Promise.all(conversions); | ||
|
||
const operatorData = convertedData.map((data) => | ||
Array.isArray(data) ? data[1] : data | ||
) as D; | ||
const result = await operator(backend, operatorData, args); | ||
|
||
// Deconvert the converted data to its original form in case all of the data does not share the same backend type | ||
const deconversions: Promise<void>[] = []; | ||
for (const dataPair of convertedData) { | ||
if (Array.isArray(dataPair)) { | ||
deconversions.push((async () => { | ||
const [data, temporaryData] = dataPair; | ||
await data.set(await temporaryData.get()); | ||
temporaryData.dispose(); | ||
})()); | ||
} | ||
} | ||
await Promise.all(deconversions); | ||
|
||
return result; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import { Data, DataType } from "../../types/data.ts"; | ||
import { WasmBackend } from "../../wasm/backend.ts"; | ||
import { Core, operatorOnBackend } from "../core.ts"; | ||
import * as wasm from "../../wasm/operators/binary.ts"; | ||
import * as webgpu from "../../webgpu/operators/binary.ts"; | ||
import { WebGPUBackend } from "../../webgpu/backend.ts"; | ||
import { BackendOperator } from "../../types/backend.ts"; | ||
import { WebGPUData } from "../../webgpu/data.ts"; | ||
import { WasmData } from "../../wasm/data.ts"; | ||
|
||
export function binary<T extends DataType>( | ||
wasmOperator: BackendOperator< | ||
WasmBackend, | ||
[WasmData<T>, WasmData<T>, WasmData<T>], | ||
undefined, | ||
void | ||
>, | ||
webgpuOperator: BackendOperator< | ||
WebGPUBackend, | ||
[WebGPUData<T>, WebGPUData<T>, WebGPUData<T>], | ||
undefined, | ||
Promise<void> | ||
>, | ||
) { | ||
return async function (core: Core, data: [Data<T>, Data<T>, Data<T>]) { | ||
if (Math.max(...data.map(({ length }) => length)) <= 80 * 80) { | ||
return await operatorOnBackend( | ||
core.backends.get("wasm")! as WasmBackend, | ||
wasmOperator, | ||
data, | ||
undefined, | ||
); | ||
} else { | ||
return await operatorOnBackend( | ||
core.backends.get("webgpu")! as WebGPUBackend, | ||
webgpuOperator, | ||
data, | ||
undefined, | ||
); | ||
} | ||
}; | ||
} | ||
|
||
export const add = binary(wasm.add, webgpu.add); | ||
export const sub = binary(wasm.sub, webgpu.sub); | ||
export const mul = binary(wasm.mul, webgpu.mul); | ||
export const div = binary(wasm.div, webgpu.div); | ||
export const mod = binary(wasm.mod, webgpu.mod); | ||
export const min = binary(wasm.min, webgpu.min); | ||
export const max = binary(wasm.max, webgpu.max); | ||
export const prelu = binary(wasm.prelu, webgpu.prelu); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import { Data } from "../../types/data.ts"; | ||
import { WasmBackend } from "../../wasm/backend.ts"; | ||
import { Core, operatorOnBackend } from "../core.ts"; | ||
import { matmul as wasmMatmul } from "../../wasm/operators/matmul.ts"; | ||
import { matmul as webgpuMatmul } from "../../webgpu/operators/matmul.ts"; | ||
import { WebGPUBackend } from "../../webgpu/backend.ts"; | ||
|
||
export async function matmul<T extends "f32" | "u32" | "i32">( | ||
core: Core, | ||
data: [Data<T>, Data<T>, Data<T>], | ||
args: { m: number; n: number; k: number }, | ||
) { | ||
if (Math.max(...data.map(({ length }) => length)) <= 72 * 72) { | ||
return await operatorOnBackend( | ||
core.backends.get("wasm")! as WasmBackend, | ||
wasmMatmul, | ||
data, | ||
args, | ||
); | ||
} else { | ||
return await operatorOnBackend( | ||
core.backends.get("webgpu")! as WebGPUBackend, | ||
webgpuMatmul, | ||
data, | ||
args, | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import { Data } from "../../types/data.ts"; | ||
import { WasmBackend } from "../../wasm/backend.ts"; | ||
import { Core, operatorOnBackend } from "../core.ts"; | ||
import { transpose as wasmTranspose } from "../../wasm/operators/transpose.ts"; | ||
import { transpose as webgpuTranspose } from "../../webgpu/operators/transpose.ts"; | ||
import { WebGPUBackend } from "../../webgpu/backend.ts"; | ||
|
||
export async function transpose<T extends "f32" | "u32" | "i32">( | ||
core: Core, | ||
data: [Data<T>, Data<T>], | ||
args: { w: number; h: number }, | ||
) { | ||
if (Math.max(...data.map(({ length }) => length)) <= 90 * 90) { | ||
return await operatorOnBackend( | ||
core.backends.get("wasm")! as WasmBackend, | ||
wasmTranspose, | ||
data, | ||
args, | ||
); | ||
} else { | ||
return await operatorOnBackend( | ||
core.backends.get("webgpu")! as WebGPUBackend, | ||
webgpuTranspose, | ||
data, | ||
args, | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import { Data, DataType } from "../../types/data.ts"; | ||
import { WasmBackend } from "../../wasm/backend.ts"; | ||
import { Core, operatorOnBackend } from "../core.ts"; | ||
import * as wasm from "../../wasm/operators/unary.ts"; | ||
import * as webgpu from "../../webgpu/operators/unary.ts"; | ||
import { WebGPUBackend } from "../../webgpu/backend.ts"; | ||
import { BackendOperator } from "../../types/backend.ts"; | ||
import { WebGPUData } from "../../webgpu/data.ts"; | ||
import { WasmData } from "../../wasm/data.ts"; | ||
|
||
export function unary<T extends DataType>( | ||
wasmOperator: BackendOperator< | ||
WasmBackend, | ||
[WasmData<T>, WasmData<T>], | ||
undefined, | ||
void | ||
>, | ||
webgpuOperator: BackendOperator< | ||
WebGPUBackend, | ||
[WebGPUData<T>, WebGPUData<T>], | ||
undefined, | ||
Promise<void> | ||
>, | ||
) { | ||
return async function (core: Core, data: [Data<T>, Data<T>]) { | ||
if (Math.max(...data.map(({ length }) => length)) <= 90 * 90) { | ||
return await operatorOnBackend( | ||
core.backends.get("wasm")! as WasmBackend, | ||
wasmOperator, | ||
data, | ||
undefined, | ||
); | ||
} else { | ||
return await operatorOnBackend( | ||
core.backends.get("webgpu")! as WebGPUBackend, | ||
webgpuOperator, | ||
data, | ||
undefined, | ||
); | ||
} | ||
}; | ||
} | ||
|
||
export const abs = unary(wasm.abs, webgpu.abs); | ||
export const linear = unary<"f32" | "u32" | "i32">(wasm.linear, webgpu.linear); | ||
export const neg = unary<"f32" | "i32">(wasm.neg, webgpu.neg); | ||
export const inc = unary<"f32" | "u32" | "i32">(wasm.inc, webgpu.inc); | ||
export const dec = unary<"f32" | "u32" | "i32">(wasm.dec, webgpu.dec); | ||
export const relu = unary<"f32" | "i32">(wasm.relu, webgpu.relu); | ||
export const relu6 = unary<"f32" | "i32">(wasm.relu6, webgpu.relu6); | ||
export const ceil = unary<"f32">(wasm.ceil, webgpu.ceil); | ||
export const floor = unary<"f32">(wasm.floor, webgpu.floor); | ||
export const round = unary<"f32">(wasm.round, webgpu.round); | ||
export const sqrt = unary<"f32">(wasm.sqrt, webgpu.sqrt); | ||
export const rsqrt = unary<"f32">(wasm.rsqrt, webgpu.rsqrt); | ||
export const selu = unary<"f32">(wasm.selu, webgpu.selu); | ||
export const sigmoid = unary<"f32">(wasm.sigmoid, webgpu.sigmoid); | ||
export const square = unary<"f32" | "u32" | "i32">(wasm.square, webgpu.square); | ||
export const cos = unary<"f32">(wasm.cos, webgpu.cos); | ||
export const cosh = unary<"f32">(wasm.cosh, webgpu.cosh); | ||
export const sin = unary<"f32">(wasm.sin, webgpu.sin); | ||
export const sinh = unary<"f32">(wasm.sinh, webgpu.sinh); | ||
export const tan = unary<"f32">(wasm.tan, webgpu.tan); | ||
export const tanh = unary<"f32">(wasm.tanh, webgpu.tanh); | ||
export const exp = unary<"f32">(wasm.exp, webgpu.exp); | ||
export const elu = unary<"f32">(wasm.elu, webgpu.elu); | ||
export const log = unary<"f32">(wasm.log, webgpu.log); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,6 @@ | |
#![feature(core_intrinsics)] | ||
|
||
pub mod binary; | ||
pub mod matmul; | ||
pub mod transpose; | ||
pub mod unary; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.