Skip to content

Commit

Permalink
fix: Update WGPU (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinWeiTan authored Aug 20, 2022
1 parent 1a50eae commit 11fdf1e
Show file tree
Hide file tree
Showing 26 changed files with 280 additions and 749 deletions.
5 changes: 3 additions & 2 deletions backend/webgpu/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export class WebGPUBackend implements Backend {
const module = this.device.createShaderModule({ code });
const pipeline = await this.device.createComputePipelineAsync({
compute: { module, entryPoint: "main" },
layout: "auto",
});
const layout = pipeline.getBindGroupLayout(0);

Expand Down Expand Up @@ -75,8 +76,8 @@ export class WebGPUBackend implements Backend {
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setBindGroup(0, bindgroup);
passEncoder.setPipeline(pipeline);
passEncoder.dispatch(...workgroups as [number, number, number]);
passEncoder.endPass();
passEncoder.dispatchWorkgroups(...workgroups as [number, number, number]);
passEncoder.end();

this.device.queue.submit([commandEncoder.finish()]);
}
Expand Down
46 changes: 23 additions & 23 deletions backend/webgpu/shaders/binary.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,34 @@ import { DataType } from "../../types/data.ts";

export default (type: DataType, expr: string) => `
// #import "prelude.wgsl"
var<private> e: f32 = 2.718281828459045;
var<private> pi: f32 = 3.141592653589793;
var<private> tau: f32 = 6.283185307179586;
var<private> phi: f32 = 1.618033988749895;
var<private> feigd: f32 = 4.66920160910299;
var<private> feiga: f32 = -2.5029078750958926;
var<private> e: f32 = 2.718281828459045;
var<private> pi: f32 = 3.141592653589793;
var<private> tau: f32 = 6.283185307179586;
var<private> phi: f32 = 1.618033988749895;
var<private> feigd: f32 = 4.66920160910299;
var<private> feiga: f32 = -2.5029078750958926;
var<private> gauss: f32 = 0.8346268416740732;
// #input type: DataType
// #input expr: string
// #input type: DataType
// #input expr: string
struct Data {
values: array<${type}>;
};
struct Data {
values: array<${type}>,
};
[[group(0), binding(0)]]
var<storage, read> a_data: Data;
[[group(0), binding(1)]]
var<storage, read> b_data: Data;
[[group(0), binding(2)]]
var<storage, write> c_data: Data;
@group(0) @binding(0)
var<storage, read> a_data: Data;
@group(0) @binding(1)
var<storage, read> b_data: Data;
@group(0) @binding(2)
var<storage, write> c_data: Data;
fn binary(a: ${type}, b: ${type}) -> ${type} {
${expr}
}
fn binary(a: ${type}, b: ${type}) -> ${type} {
${expr}
}
[[stage(compute), workgroup_size(128)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
c_data.values[global_id.x] = binary(a_data.values[global_id.x], b_data.values[global_id.x]);
@compute @workgroup_size(128)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
c_data.values[global_id.x] = binary(a_data.values[global_id.x], b_data.values[global_id.x]);
}
`;
12 changes: 6 additions & 6 deletions backend/webgpu/shaders/binary.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
#input expr: string

struct Data {
values: array<type>;
values: array<type>,
};

[[group(0), binding(0)]]
@group(0) @binding(0)
var<storage, read> a_data: Data;
[[group(0), binding(1)]]
@group(0) @binding(1)
var<storage, read> b_data: Data;
[[group(0), binding(2)]]
@group(0) @binding(2)
var<storage, write> c_data: Data;

fn binary(a: type, b: type) -> type {
expr
}

[[stage(compute), workgroup_size(128)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
@compute @workgroup_size(128)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
c_data.values[global_id.x] = binary(a_data.values[global_id.x], b_data.values[global_id.x]);
}
76 changes: 38 additions & 38 deletions backend/webgpu/shaders/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,45 @@ import { DataType } from "../../types/data.ts";

export default (type: DataType) => `
// #import "prelude.wgsl"
var<private> e: f32 = 2.718281828459045;
var<private> pi: f32 = 3.141592653589793;
var<private> tau: f32 = 6.283185307179586;
var<private> phi: f32 = 1.618033988749895;
var<private> feigd: f32 = 4.66920160910299;
var<private> feiga: f32 = -2.5029078750958926;
var<private> e: f32 = 2.718281828459045;
var<private> pi: f32 = 3.141592653589793;
var<private> tau: f32 = 6.283185307179586;
var<private> phi: f32 = 1.618033988749895;
var<private> feigd: f32 = 4.66920160910299;
var<private> feiga: f32 = -2.5029078750958926;
var<private> gauss: f32 = 0.8346268416740732;
// #input type: DataType
struct Uniform {
m: u32;
n: u32;
k: u32;
};
struct Data {
values: array<${type}>;
};
[[group(0), binding(0)]]
var<storage, read> a: Data;
[[group(0), binding(1)]]
var<storage, read> b: Data;
[[group(0), binding(2)]]
var<storage, write> c: Data;
[[group(0), binding(3)]]
var<uniform> meta: Uniform;
[[stage(compute), workgroup_size(8, 8, 1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
if (global_id.x >= meta.n || global_id.y >= meta.m) {
return;
}
var sum = ${type}(0);
for (var k = 0u; k < meta.k; k = k + 1u) {
sum = sum + a.values[global_id.y * meta.k + k] * b.values[k * meta.n + global_id.x];
}
c.values[global_id.x + global_id.y * meta.n] = sum;
// #input type: DataType
struct Uniforms {
m: u32,
n: u32,
k: u32,
};
struct Data {
values: array<${type}>,
};
@group(0) @binding(0)
var<storage, read> a: Data;
@group(0) @binding(1)
var<storage, read> b: Data;
@group(0) @binding(2)
var<storage, write> c: Data;
@group(0) @binding(3)
var<uniform> uniforms: Uniforms;
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
if (global_id.x >= uniforms.n || global_id.y >= uniforms.m) {
return;
}
var sum = ${type}(0);
for (var k = 0u; k < uniforms.k; k = k + 1u) {
sum = sum + a.values[global_id.y * uniforms.k + k] * b.values[k * uniforms.n + global_id.x];
}
c.values[global_id.x + global_id.y * uniforms.n] = sum;
}
`;
32 changes: 16 additions & 16 deletions backend/webgpu/shaders/matmul.wgsl
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
#import "prelude.wgsl"
#input type: DataType

struct Uniform {
m: u32;
n: u32;
k: u32;
struct Uniforms {
m: u32,
n: u32,
k: u32,
};

struct Data {
values: array<type>;
values: array<type>,
};

[[group(0), binding(0)]]
@group(0) @binding(0)
var<storage, read> a: Data;
[[group(0), binding(1)]]
@group(0) @binding(1)
var<storage, read> b: Data;
[[group(0), binding(2)]]
@group(0) @binding(2)
var<storage, write> c: Data;
[[group(0), binding(3)]]
var<uniform> meta: Uniform;
@group(0) @binding(3)
var<uniform> uniforms: Uniforms;

[[stage(compute), workgroup_size(8, 8, 1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
if (global_id.x >= meta.n || global_id.y >= meta.m) {
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
if (global_id.x >= uniforms.n || global_id.y >= uniforms.m) {
return;
}

var sum = type(0);
for (var k = 0u; k < meta.k; k = k + 1u) {
sum = sum + a.values[global_id.y * meta.k + k] * b.values[k * meta.n + global_id.x];
for (var k = 0u; k < uniforms.k; k = k + 1u) {
sum = sum + a.values[global_id.y * uniforms.k + k] * b.values[k * uniforms.n + global_id.x];
}
c.values[global_id.x + global_id.y * meta.n] = sum;
c.values[global_id.x + global_id.y * uniforms.n] = sum;
}
62 changes: 31 additions & 31 deletions backend/webgpu/shaders/pad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,39 @@ import { DataType } from "../../types/data.ts";

export default (type: DataType) => `
// #import "prelude.wgsl"
var<private> e: f32 = 2.718281828459045;
var<private> pi: f32 = 3.141592653589793;
var<private> tau: f32 = 6.283185307179586;
var<private> phi: f32 = 1.618033988749895;
var<private> feigd: f32 = 4.66920160910299;
var<private> feiga: f32 = -2.5029078750958926;
var<private> e: f32 = 2.718281828459045;
var<private> pi: f32 = 3.141592653589793;
var<private> tau: f32 = 6.283185307179586;
var<private> phi: f32 = 1.618033988749895;
var<private> feigd: f32 = 4.66920160910299;
var<private> feiga: f32 = -2.5029078750958926;
var<private> gauss: f32 = 0.8346268416740732;
// #input type: DataType
struct Uniform {
w: u32;
h: u32;
n: u32;
};
struct Data {
values: array<${type}>;
};
[[group(0), binding(0)]]
var<storage, read> a: Data;
[[group(0), binding(1)]]
var<storage, write> b: Data;
[[group(0), binding(2)]]
var<uniform> meta: Uniform;
[[stage(compute), workgroup_size(8, 8, 1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
if (global_id.x >= meta.w || global_id.y >= meta.h) {
return;
}
// #input type: DataType
struct Uniforms {
w: u32,
h: u32,
n: u32,
};
struct Data {
values: array<${type}>,
};
@group(0) @binding(0)
var<storage, read> a: Data;
@group(0) @binding(1)
var<storage, write> b: Data;
@group(0) @binding(2)
var<uniform> uniforms: Uniforms;
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
if (global_id.x >= uniforms.w || global_id.y >= uniforms.h) {
return;
}
b.values[global_id.x + global_id.y * meta.n] = a.values[global_id.x + global_id.y * meta.w];
b.values[global_id.x + global_id.y * uniforms.n] = a.values[global_id.x + global_id.y * uniforms.w];
}
`;
26 changes: 13 additions & 13 deletions backend/webgpu/shaders/pad.wgsl
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
#import "prelude.wgsl"
#input type: DataType

struct Uniform {
w: u32;
h: u32;
n: u32;
struct Uniforms {
w: u32,
h: u32,
n: u32,
};

struct Data {
values: array<type>;
values: array<type>,
};

[[group(0), binding(0)]]
@group(0) @binding(0)
var<storage, read> a: Data;
[[group(0), binding(1)]]
@group(0) @binding(1)
var<storage, write> b: Data;
[[group(0), binding(2)]]
var<uniform> meta: Uniform;
@group(0) @binding(2)
var<uniform> uniforms: Uniforms;

[[stage(compute), workgroup_size(8, 8, 1)]]
fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
if (global_id.x >= meta.w || global_id.y >= meta.h) {
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
if (global_id.x >= uniforms.w || global_id.y >= uniforms.h) {
return;
}

b.values[global_id.x + global_id.y * meta.n] = a.values[global_id.x + global_id.y * meta.w];
b.values[global_id.x + global_id.y * uniforms.n] = a.values[global_id.x + global_id.y * uniforms.w];
}
Loading

0 comments on commit 11fdf1e

Please sign in to comment.