Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add postprocessing (and fix WASM) #65

Merged
merged 4 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ safetensors = { workspace = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen = "0.2.92"
getrandom = { version = "0.2", features = ["js"] }
js-sys = "0.3.69"
js-sys = "0.3.69"
54 changes: 31 additions & 23 deletions crates/core/src/cpu/backend.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use std::collections::HashMap;
use std::time::Instant;

use ndarray::{ArrayD, ArrayViewD, IxDyn};
use safetensors::{serialize, SafeTensors};

use crate::{
to_arr, ActivationCPULayer, BackendConfig, BatchNorm1DCPULayer, BatchNorm2DCPULayer,
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUScheduler, Conv2DCPULayer, ConvTensors,
ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors, Dropout1DCPULayer,
Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger, Pool2DCPULayer, SoftmaxCPULayer,
Tensor, Tensors,
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUPostProcessor, CPUScheduler,
Conv2DCPULayer, ConvTensors, ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors,
Dropout1DCPULayer, Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger,
Pool2DCPULayer, PostProcessor, SoftmaxCPULayer, Tensor, Tensors, Timer,
};

pub struct Backend {
Expand All @@ -23,10 +22,16 @@ pub struct Backend {
pub optimizer: CPUOptimizer,
pub scheduler: CPUScheduler,
pub logger: Logger,
pub timer: Timer,
}

impl Backend {
pub fn new(config: BackendConfig, logger: Logger, mut tensors: Option<Vec<Tensors>>) -> Self {
pub fn new(
config: BackendConfig,
logger: Logger,
timer: Timer,
mut tensors: Option<Vec<Tensors>>,
) -> Self {
let mut layers = Vec::new();
let mut size = config.size.clone();
for layer in config.layers.iter() {
Expand Down Expand Up @@ -99,6 +104,7 @@ impl Backend {
optimizer,
scheduler,
size,
timer,
}
}

Expand Down Expand Up @@ -147,7 +153,7 @@ impl Backend {
let mut cost = 0f32;
let mut time: u128;
let mut total_time = 0u128;
let start = Instant::now();
let start = (self.timer.now)();
let total_iter = epochs * datasets.len();
while epoch < epochs {
let mut total = 0.0;
Expand All @@ -160,11 +166,11 @@ impl Backend {
let minibatch = outputs.dim()[0];
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
cost = total / (batches) as f32;
time = start.elapsed().as_millis() - total_time;
time = ((self.timer.now)() - start) - total_time;
total_time += time;
let current_iter = epoch * datasets.len() + i;
let msg = format!(
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
"Epoch={}, Dataset={}, Cost={}, Time={:.3}s, ETA={:.3}s",
epoch,
i * minibatch,
cost,
Expand All @@ -188,25 +194,20 @@ impl Backend {
} else {
disappointments += 1;
if !self.silent {
println!(
(self.logger.log)(format!(
"Patience counter: {} disappointing epochs out of {}.",
disappointments, self.patience
);
));
}
}
if disappointments >= self.patience {
if !self.silent {
println!(
(self.logger.log)(format!(
"No improvement for {} epochs. Stopping early at cost={}",
disappointments, best_cost
);
));
}
let net = Self::load(
&best_net,
Logger {
log: |x| println!("{}", x),
},
);
let net = Self::load(&best_net, self.logger.clone(), self.timer.clone());
self.layers = net.layers;
break;
}
Expand All @@ -215,11 +216,18 @@ impl Backend {
}
}

pub fn predict(&mut self, data: ArrayD<f32>, layers: Option<Vec<usize>>) -> ArrayD<f32> {
pub fn predict(
&mut self,
data: ArrayD<f32>,
postprocess: PostProcessor,
layers: Option<Vec<usize>>,
) -> ArrayD<f32> {
let processor = CPUPostProcessor::from(&postprocess);
for layer in &mut self.layers {
layer.reset(1);
}
self.forward_propagate(data, false, layers)
let res = self.forward_propagate(data, false, layers);
processor.process(res)
}

pub fn save(&self) -> Vec<u8> {
Expand Down Expand Up @@ -272,7 +280,7 @@ impl Backend {
serialize(tensors, &Some(metadata)).unwrap()
}

pub fn load(buffer: &[u8], logger: Logger) -> Self {
pub fn load(buffer: &[u8], logger: Logger, timer: Timer) -> Self {
let tensors = SafeTensors::deserialize(buffer).unwrap();
let (_, metadata) = SafeTensors::read_metadata(buffer).unwrap();
let data = metadata.metadata().as_ref().unwrap();
Expand Down Expand Up @@ -304,6 +312,6 @@ impl Backend {
};
}

Backend::new(config, logger, Some(layers))
Backend::new(config, logger, timer, Some(layers))
}
}
4 changes: 3 additions & 1 deletion crates/core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod layers;
mod optimizers;
mod schedulers;
mod regularizer;
mod postprocessing;

pub use activation::*;
pub use backend::*;
Expand All @@ -14,4 +15,5 @@ pub use init::*;
pub use layers::*;
pub use optimizers::*;
pub use schedulers::*;
pub use regularizer::*;
pub use regularizer::*;
pub use postprocessing::*;
28 changes: 28 additions & 0 deletions crates/core/src/cpu/postprocessing/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use ndarray::ArrayD;
use crate::PostProcessor;

mod step;
use step::CPUStepFunction;

pub enum CPUPostProcessor {
None,
Sign,
Step(CPUStepFunction),
}

impl CPUPostProcessor {
pub fn from(processor: &PostProcessor) -> Self {
match processor {
PostProcessor::None => CPUPostProcessor::None,
PostProcessor::Sign => CPUPostProcessor::Sign,
PostProcessor::Step(config) => CPUPostProcessor::Step(CPUStepFunction::new(config)),
}
}
pub fn process(&self, x: ArrayD<f32>) -> ArrayD<f32> {
match self {
CPUPostProcessor::None => x,
CPUPostProcessor::Sign => x.map(|y| y.signum()),
CPUPostProcessor::Step(processor) => x.map(|y| processor.step(*y)),
}
}
}
22 changes: 22 additions & 0 deletions crates/core/src/cpu/postprocessing/step.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::StepFunctionConfig;

pub struct CPUStepFunction {
thresholds: Vec<f32>,
values: Vec<f32>
}
impl CPUStepFunction {
pub fn new(config: &StepFunctionConfig) -> Self {
return Self {
thresholds: config.thresholds.clone(),
values: config.values.clone()
}
}
pub fn step(&self, x: f32) -> f32 {
for (i, &threshold) in self.thresholds.iter().enumerate() {
if x < threshold {
return self.values[i];
}
}
return self.values.last().unwrap().clone()
}
}
18 changes: 13 additions & 5 deletions crates/core/src/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::slice::{from_raw_parts, from_raw_parts_mut};
use std::time::{SystemTime, UNIX_EPOCH};

use crate::{
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, TrainOptions,
RESOURCES,
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, Timer,
TrainOptions, RESOURCES,
};

type AllocBufferFn = extern "C" fn(usize) -> *mut u8;
Expand All @@ -11,10 +12,17 @@ fn log(string: String) {
println!("{}", string)
}

fn now() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Your system is behind the Unix Epoch")
.as_millis()
}

#[no_mangle]
pub extern "C" fn ffi_backend_create(ptr: *const u8, len: usize, alloc: AllocBufferFn) -> usize {
let config = decode_json(ptr, len);
let net_backend = Backend::new(config, Logger { log }, None);
let net_backend = Backend::new(config, Logger { log }, Timer { now }, None);
let buf: Vec<u8> = net_backend
.size
.iter()
Expand Down Expand Up @@ -75,7 +83,7 @@ pub extern "C" fn ffi_backend_predict(

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
let res = backend[id].predict(inputs, options.layers);
let res = backend[id].predict(inputs, options.post_process, options.layers);
outputs.copy_from_slice(res.as_slice().unwrap());
});
}
Expand All @@ -98,7 +106,7 @@ pub extern "C" fn ffi_backend_load(
alloc: AllocBufferFn,
) -> usize {
let buffer = unsafe { from_raw_parts(file_ptr, file_len) };
let net_backend = Backend::load(buffer, Logger { log });
let net_backend = Backend::load(buffer, Logger { log }, Timer { now });
let buf: Vec<u8> = net_backend.size.iter().map(|x| *x as u8).collect();
let size_ptr = alloc(buf.len());
let output_shape = unsafe { from_raw_parts_mut(size_ptr, buf.len()) };
Expand Down
16 changes: 16 additions & 0 deletions crates/core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ pub enum Scheduler {
OneCycle(OneCycleScheduler),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StepFunctionConfig {
pub thresholds: Vec<f32>,
pub values: Vec<f32>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", content = "config")]
#[serde(rename_all = "lowercase")]
pub enum PostProcessor {
None,
Sign,
Step(StepFunctionConfig),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct TrainOptions {
Expand All @@ -212,6 +227,7 @@ pub struct PredictOptions {
pub input_shape: Vec<usize>,
pub output_shape: Vec<usize>,
pub layers: Option<Vec<usize>>,
pub post_process: PostProcessor,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down
6 changes: 6 additions & 0 deletions crates/core/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ use ndarray::ArrayD;
use safetensors::tensor::TensorView;
use serde::Deserialize;

#[derive(Clone)]
pub struct Logger {
pub log: fn(string: String) -> (),
}

#[derive(Clone)]
pub struct Timer {
pub now: fn() -> u128,
}

pub fn length(shape: Vec<usize>) -> usize {
return shape.iter().fold(1, |i, x| i * x);
}
Expand Down
31 changes: 23 additions & 8 deletions crates/core/src/wasm.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
use js_sys::{Array, Float32Array, Uint8Array};
use ndarray::ArrayD;

use wasm_bindgen::{prelude::wasm_bindgen, JsValue};

use crate::{Backend, Dataset, Logger, PredictOptions, TrainOptions, RESOURCES};
use crate::{Backend, Dataset, Logger, PredictOptions, Timer, TrainOptions, RESOURCES};

#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
#[wasm_bindgen(js_namespace = Date)]
fn now() -> f64;

}

fn console_log(string: String) {
log(string.as_str())
}

fn performance_now() -> u128 {
now() as u128
}

#[wasm_bindgen]
pub fn wasm_backend_create(config: String, shape: Array) -> usize {
let config = serde_json::from_str(&config).unwrap();
let mut len = 0;
let logger = Logger { log: console_log };
let net_backend = Backend::new(config, logger, None);
let net_backend = Backend::new(
config,
logger,
Timer {
now: performance_now,
},
None,
);
shape.set_length(net_backend.size.len() as u32);
for (i, s) in net_backend.size.iter().enumerate() {
shape.set(i as u32, JsValue::from(*s))
Expand All @@ -37,7 +50,6 @@ pub fn wasm_backend_create(config: String, shape: Array) -> usize {
#[wasm_bindgen]
pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String) {
let options: TrainOptions = serde_json::from_str(&options).unwrap();

let mut datasets = Vec::new();
for i in 0..options.datasets {
let input = buffers[i * 2].to_vec();
Expand All @@ -47,7 +59,6 @@ pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String
outputs: ArrayD::from_shape_vec(options.output_shape.clone(), output).unwrap(),
});
}

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
backend[id].train(datasets, options.epochs, options.batches, options.rate)
Expand All @@ -59,11 +70,12 @@ pub fn wasm_backend_predict(id: usize, buffer: Float32Array, options: String) ->
let options: PredictOptions = serde_json::from_str(&options).unwrap();
let inputs = ArrayD::from_shape_vec(options.input_shape, buffer.to_vec()).unwrap();

let res = ArrayD::zeros(options.output_shape);
let mut res = ArrayD::zeros(options.output_shape.clone());

RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
let _res = backend[id].predict(inputs, options.layers);
let _res = backend[id].predict(inputs, options.post_process, options.layers);
res.assign(&ArrayD::from_shape_vec(options.output_shape, _res.as_slice().unwrap().to_vec()).unwrap());
});
Float32Array::from(res.as_slice().unwrap())
}
Expand All @@ -82,7 +94,10 @@ pub fn wasm_backend_save(id: usize) -> Uint8Array {
pub fn wasm_backend_load(buffer: Uint8Array, shape: Array) -> usize {
let mut len = 0;
let logger = Logger { log: console_log };
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger);
let timer = Timer {
now: performance_now,
};
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger, timer);
shape.set_length(net_backend.size.len() as u32);
for (i, s) in net_backend.size.iter().enumerate() {
shape.set(i as u32, JsValue::from(*s))
Expand Down
Loading
Loading