fix wasm
retraigo committed Sep 24, 2024
1 parent 928a2a5 commit c159b2e
Showing 9 changed files with 109 additions and 45 deletions.
2 changes: 1 addition & 1 deletion crates/core/Cargo.toml
@@ -16,4 +16,4 @@ safetensors = { workspace = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen = "0.2.92"
getrandom = { version = "0.2", features = ["js"] }
js-sys = "0.3.69"
js-sys = "0.3.69"
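The wasm32-only dependency table above is what the new timing code relies on: js-sys exposes JavaScript globals such as Date to Rust, and getrandom's "js" feature routes random bytes through the host's crypto API so the crate links on wasm32. As a minimal sketch (not part of this commit), the same cfg predicate can be used to pick a clock source per target:

#[cfg(target_arch = "wasm32")]
fn now_ms() -> u128 {
    // js_sys::Date::now() returns milliseconds since the Unix epoch as f64.
    js_sys::Date::now() as u128
}

#[cfg(not(target_arch = "wasm32"))]
fn now_ms() -> u128 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before the Unix epoch")
        .as_millis()
}

The commit takes a different route and injects the clock at construction time as a Timer value (see backend.rs below), which keeps the training loop free of target-specific code.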
40 changes: 21 additions & 19 deletions crates/core/src/cpu/backend.rs
@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::time::Instant;

use ndarray::{ArrayD, ArrayViewD, IxDyn};
use safetensors::{serialize, SafeTensors};
@@ -9,7 +8,7 @@ use crate::{
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUPostProcessor, CPUScheduler,
Conv2DCPULayer, ConvTensors, ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors,
Dropout1DCPULayer, Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger,
Pool2DCPULayer, PostProcessor, SoftmaxCPULayer, Tensor, Tensors,
Pool2DCPULayer, PostProcessor, SoftmaxCPULayer, Tensor, Tensors, Timer,
};

pub struct Backend {
@@ -23,10 +22,16 @@ pub struct Backend {
pub optimizer: CPUOptimizer,
pub scheduler: CPUScheduler,
pub logger: Logger,
pub timer: Timer,
}

impl Backend {
pub fn new(config: BackendConfig, logger: Logger, mut tensors: Option<Vec<Tensors>>) -> Self {
pub fn new(
config: BackendConfig,
logger: Logger,
timer: Timer,
mut tensors: Option<Vec<Tensors>>,
) -> Self {
let mut layers = Vec::new();
let mut size = config.size.clone();
for layer in config.layers.iter() {
@@ -99,6 +104,7 @@ impl Backend {
optimizer,
scheduler,
size,
timer,
}
}

@@ -147,7 +153,7 @@ impl Backend {
let mut cost = 0f32;
let mut time: u128;
let mut total_time = 0u128;
let start = Instant::now();
let start = (self.timer.now)();
let total_iter = epochs * datasets.len();
while epoch < epochs {
let mut total = 0.0;
@@ -160,11 +166,11 @@
let minibatch = outputs.dim()[0];
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
cost = total / (batches) as f32;
time = start.elapsed().as_millis() - total_time;
time = ((self.timer.now)() - start) - total_time;
total_time += time;
let current_iter = epoch * datasets.len() + i;
let msg = format!(
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
"Epoch={}, Dataset={}, Cost={}, Time={:.3}s, ETA={:.3}s",
epoch,
i * minibatch,
cost,
@@ -188,25 +194,20 @@
} else {
disappointments += 1;
if !self.silent {
println!(
(self.logger.log)(format!(
"Patience counter: {} disappointing epochs out of {}.",
disappointments, self.patience
);
));
}
}
if disappointments >= self.patience {
if !self.silent {
println!(
(self.logger.log)(format!(
"No improvement for {} epochs. Stopping early at cost={}",
disappointments, best_cost
);
));
}
let net = Self::load(
&best_net,
Logger {
log: |x| println!("{}", x),
},
);
let net = Self::load(&best_net, self.logger.clone(), self.timer.clone());
self.layers = net.layers;
break;
}
@@ -225,7 +226,8 @@ impl Backend {
for layer in &mut self.layers {
layer.reset(1);
}
processor.process(self.forward_propagate(data, false, layers))
let res = self.forward_propagate(data, false, layers);
processor.process(res)
}

pub fn save(&self) -> Vec<u8> {
@@ -278,7 +280,7 @@ impl Backend {
serialize(tensors, &Some(metadata)).unwrap()
}

pub fn load(buffer: &[u8], logger: Logger) -> Self {
pub fn load(buffer: &[u8], logger: Logger, timer: Timer) -> Self {
let tensors = SafeTensors::deserialize(buffer).unwrap();
let (_, metadata) = SafeTensors::read_metadata(buffer).unwrap();
let data = metadata.metadata().as_ref().unwrap();
@@ -310,6 +312,6 @@ impl Backend {
};
}

Backend::new(config, logger, Some(layers))
Backend::new(config, logger, timer, Some(layers))
}
}
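Backend::new and Backend::load now take the clock alongside the logger, so backend.rs no longer needs std::time::Instant, which is unusable at runtime on wasm32-unknown-unknown (Instant::now panics there). A minimal native-side sketch of wiring both function pointers in, assuming Backend, BackendConfig, Logger and Timer are importable from the crate root (the exact paths are not shown in this diff):

use std::time::{SystemTime, UNIX_EPOCH};

fn log(message: String) {
    println!("{}", message);
}

fn now() -> u128 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before the Unix epoch")
        .as_millis()
}

fn build_backend(config: BackendConfig) -> Backend {
    // None = start from freshly initialized tensors rather than a saved snapshot.
    Backend::new(config, Logger { log }, Timer { now }, None)
}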
16 changes: 12 additions & 4 deletions crates/core/src/ffi.rs
@@ -1,8 +1,9 @@
use std::slice::{from_raw_parts, from_raw_parts_mut};
use std::time::{SystemTime, UNIX_EPOCH};

use crate::{
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, TrainOptions,
RESOURCES,
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, Timer,
TrainOptions, RESOURCES,
};

type AllocBufferFn = extern "C" fn(usize) -> *mut u8;
@@ -11,10 +12,17 @@ fn log(string: String) {
println!("{}", string)
}

fn now() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Your system is behind the Unix Epoch")
.as_millis()
}

#[no_mangle]
pub extern "C" fn ffi_backend_create(ptr: *const u8, len: usize, alloc: AllocBufferFn) -> usize {
let config = decode_json(ptr, len);
let net_backend = Backend::new(config, Logger { log }, None);
let net_backend = Backend::new(config, Logger { log }, Timer { now }, None);
let buf: Vec<u8> = net_backend
.size
.iter()
@@ -98,7 +106,7 @@ pub extern "C" fn ffi_backend_load(
alloc: AllocBufferFn,
) -> usize {
let buffer = unsafe { from_raw_parts(file_ptr, file_len) };
let net_backend = Backend::load(buffer, Logger { log });
let net_backend = Backend::load(buffer, Logger { log }, Timer { now });
let buf: Vec<u8> = net_backend.size.iter().map(|x| *x as u8).collect();
let size_ptr = alloc(buf.len());
let output_shape = unsafe { from_raw_parts_mut(size_ptr, buf.len()) };
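Both clock sources report absolute milliseconds (Unix-epoch time here, Date.now() in the wasm module below), and the training loop only ever subtracts two readings, so the choice of epoch cancels out. A small illustrative helper, not part of the commit, showing the subtraction pattern that replaced Instant::elapsed:

fn measure_ms<F: FnOnce()>(timer: &Timer, work: F) -> u128 {
    let start = (timer.now)();
    work();
    // Same arithmetic as Backend::train: elapsed = now() - start.
    (timer.now)() - start
}

For example, measure_ms(&Timer { now }, || { /* one training epoch */ }) would report how long the closure ran, in whole milliseconds.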
6 changes: 6 additions & 0 deletions crates/core/src/util.rs
@@ -4,10 +4,16 @@ use ndarray::ArrayD
use safetensors::tensor::TensorView;
use serde::Deserialize;

#[derive(Clone)]
pub struct Logger {
pub log: fn(string: String) -> (),
}

#[derive(Clone)]
pub struct Timer {
pub now: fn() -> u128,
}

pub fn length(shape: Vec<usize>) -> usize {
return shape.iter().fold(1, |i, x| i * x);
}
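Logger and Timer wrap plain fn pointers rather than boxed closures, so the derived Clone is a trivial pointer copy; that is what lets the early-stopping path pass self.logger.clone() and self.timer.clone() back into Self::load. A small sketch with stand-in values (not from the commit) of what fits those fields:

fn fixed_clock() -> u128 {
    42 // stand-in; real callers supply a wall-clock function
}

fn make_handles() -> (Logger, Timer) {
    // Free functions and non-capturing closures coerce to plain fn pointers;
    // a closure that captured its environment would be rejected here.
    let logger = Logger { log: |s| println!("{}", s) };
    let timer = Timer { now: fixed_clock };
    // Cloning only copies the pointers.
    let _backup = (logger.clone(), timer.clone());
    (logger, timer)
}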
31 changes: 23 additions & 8 deletions crates/core/src/wasm.rs
@@ -1,26 +1,39 @@
use js_sys::{Array, Float32Array, Uint8Array};
use ndarray::ArrayD;

use wasm_bindgen::{prelude::wasm_bindgen, JsValue};

use crate::{Backend, Dataset, Logger, PredictOptions, TrainOptions, RESOURCES};
use crate::{Backend, Dataset, Logger, PredictOptions, Timer, TrainOptions, RESOURCES};

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
#[wasm_bindgen(js_namespace = Date)]
fn now() -> f64;

}

fn console_log(string: String) {
log(string.as_str())
}

fn performance_now() -> u128 {
now() as u128
}

#[wasm_bindgen]
pub fn wasm_backend_create(config: String, shape: Array) -> usize {
let config = serde_json::from_str(&config).unwrap();
let mut len = 0;
let logger = Logger { log: console_log };
let net_backend = Backend::new(config, logger, None);
let net_backend = Backend::new(
config,
logger,
Timer {
now: performance_now,
},
None,
);
shape.set_length(net_backend.size.len() as u32);
for (i, s) in net_backend.size.iter().enumerate() {
shape.set(i as u32, JsValue::from(*s))
@@ -37,7 +50,6 @@ pub fn wasm_backend_create(config: String, shape: Array) -> usize {
#[wasm_bindgen]
pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String) {
let options: TrainOptions = serde_json::from_str(&options).unwrap();

let mut datasets = Vec::new();
for i in 0..options.datasets {
let input = buffers[i * 2].to_vec();
@@ -47,7 +59,6 @@ pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String
outputs: ArrayD::from_shape_vec(options.output_shape.clone(), output).unwrap(),
});
}

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
backend[id].train(datasets, options.epochs, options.batches, options.rate)
@@ -59,11 +70,12 @@
let options: PredictOptions = serde_json::from_str(&options).unwrap();
let inputs = ArrayD::from_shape_vec(options.input_shape, buffer.to_vec()).unwrap();

let res = ArrayD::zeros(options.output_shape);
let mut res = ArrayD::zeros(options.output_shape.clone());

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
let _res = backend[id].predict(inputs, options.layers);
let _res = backend[id].predict(inputs, options.post_process, options.layers);
res.assign(&ArrayD::from_shape_vec(options.output_shape, _res.as_slice().unwrap().to_vec()).unwrap());
});
Float32Array::from(res.as_slice().unwrap())
}
@@ -82,7 +94,10 @@ pub fn wasm_backend_save(id: usize) -> Uint8Array {
pub fn wasm_backend_load(buffer: Uint8Array, shape: Array) -> usize {
let mut len = 0;
let logger = Logger { log: console_log };
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger);
let timer = Timer {
now: performance_now,
};
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger, timer);
shape.set_length(net_backend.size.len() as u32);
for (i, s) in net_backend.size.iter().enumerate() {
shape.set(i as u32, JsValue::from(*s))
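On wasm32 the clock comes from the JavaScript host: the extern block binds console.log and Date.now, and performance_now truncates the returned f64 milliseconds to u128. If sub-millisecond resolution were ever wanted, the same binding pattern could target performance.now() instead; this is only a sketch assuming a host where performance is a global (browsers, Deno), not something the commit does:

use wasm_bindgen::prelude::wasm_bindgen;

#[wasm_bindgen]
extern "C" {
    // Resolves to globalThis.performance.now(): fractional milliseconds
    // measured from the page/runtime start rather than the Unix epoch.
    // Differences between two readings are still valid durations.
    #[wasm_bindgen(js_namespace = performance, js_name = now)]
    fn performance_now_js() -> f64;
}

fn high_res_now() -> u128 {
    // Truncates fractional milliseconds, like the cast in performance_now above.
    performance_now_js() as u128
}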
6 changes: 4 additions & 2 deletions examples/classification/binary_iris.ts
@@ -17,6 +17,7 @@ import {
useSplit,
} from "../../packages/utilities/mod.ts";
import { PostProcess } from "../../packages/core/src/core/api/postprocess.ts";
import { AdamOptimizer, WASM } from "../../mod.ts";

// Define classes
const classes = ["Setosa", "Versicolor"];
@@ -33,7 +34,7 @@ const y = data.map((fl) => classes.indexOf(fl[4]));
const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y);

// Setup the CPU backend for Netsaur
await setupBackend(CPU);
await setupBackend(WASM);

// Create a sequential neural network
const net = new Sequential({
@@ -57,6 +58,7 @@ const net = new Sequential({
],
// We are using Log Loss for finding cost
cost: Cost.BinCrossEntropy,
optimizer: AdamOptimizer()
});

const time = performance.now();
@@ -70,7 +72,7 @@ net.train(
},
],
// Train for 150 epochs
150,
100,
1,
// Use a smaller learning rate
0.02
34 changes: 28 additions & 6 deletions packages/core/src/backends/wasm/backend.ts
@@ -9,6 +9,7 @@ import {
wasm_backend_save,
wasm_backend_train,
} from "./lib/netsaur.generated.js";
import type { PostProcessor } from "../../core/api/postprocess.ts";

/**
* Web Assembly Backend.
@@ -32,7 +33,7 @@ export class WASMBackend implements Backend {
datasets: DataSet[],
epochs: number,
batches: number,
rate: number,
rate: number
): void {
this.outputShape = datasets[0].outputs.shape.slice(1) as Shape<Rank>;
const buffer = [];
@@ -52,18 +53,39 @@
wasm_backend_train(this.#id, buffer, options);
}

async predict(
input: Tensor<Rank>,
config: { postProcess: PostProcessor; outputShape?: Shape<Rank> }
): Promise<Tensor<Rank>>;
async predict(
input: Tensor<Rank>,
config: { postProcess: PostProcessor; outputShape?: Shape<Rank> },
layers: number[]
): Promise<Tensor<Rank>>;
//deno-lint-ignore require-await
async predict(input: Tensor<Rank>): Promise<Tensor<Rank>> {
async predict(
input: Tensor<Rank>,
config: { postProcess: PostProcessor; outputShape?: Shape<Rank> },
layers?: number[]
): Promise<Tensor<Rank>> {
const options = JSON.stringify({
inputShape: [1, ...input.shape],
outputShape: this.outputShape,
inputShape: input.shape,
outputShape: [input.shape[0], ...(config.outputShape ?? this.outputShape)],
postProcess: config.postProcess,
layers,
} as PredictOptions);
const output = wasm_backend_predict(
this.#id,
input.data as Float32Array,
options,
options
);
return new Tensor(
output,
[
input.shape[0],
...(config.outputShape ?? this.outputShape),
] as Shape<Rank>,
);
return new Tensor(output, this.outputShape!);
}

save(): Uint8Array {