From 0eff1c3d66025e9bb0abb5814bba3bf7d4f9952f Mon Sep 17 00:00:00 2001 From: amine dirhoussi Date: Thu, 21 Dec 2023 10:43:46 +0100 Subject: [PATCH 1/3] erf operator implementation + test --- Cargo.lock | 12 +++++++++++ wonnx/src/compiler.rs | 19 +++++++++++++++++ wonnx/templates/endomorphism/erf.wgsl | 16 ++++++++++++++ wonnx/tests/arithmetic.rs | 30 +++++++++++++++++++++++++++ 4 files changed, 77 insertions(+) create mode 100644 wonnx/templates/endomorphism/erf.wgsl diff --git a/Cargo.lock b/Cargo.lock index 18dbe24c..0c958084 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2071,6 +2071,17 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndarray-rand" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65608f937acc725f5b164dcf40f4f0bc5d67dc268ab8a649d3002606718c4588" +dependencies = [ + "ndarray", + "rand", + "rand_distr", +] + [[package]] name = "nom" version = "7.1.3" @@ -4278,6 +4289,7 @@ dependencies = [ "image", "log", "ndarray", + "ndarray-rand", "num", "parking_lot 0.11.2", "pollster", diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 49126346..b8d02b21 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -42,6 +42,11 @@ fn get_templates() -> &'static Tera { include_str!("../templates/endomorphism/softmax.wgsl"), ) .unwrap(); + tera.add_raw_template( + "endomorphism/erf.wgsl", + include_str!("../templates/endomorphism/erf.wgsl"), + ) + .unwrap(); tera.add_raw_template( "endomorphism/map.wgsl", include_str!("../templates/endomorphism/map.wgsl"), @@ -301,6 +306,20 @@ pub fn compile( } } + "Erf" => { + let (x_threads, workgroup_size_x) = workgroup_size( + ceil(output_lengths[0], 4), + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + context.insert("workgroup_size_x", &workgroup_size_x); + NodeTemplate { + scalar_type: agreed_type(input_shapes, output_shapes)?, + template: "endomorphism/erf.wgsl", + threads: (x_threads, 1, 1), + } + } + op @ ("ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" | "ReduceProd" | "ReduceL1" | "ReduceL2" | "ReduceLogSum" | "ReduceLogSumExp" | "ReduceSumSquare") => { diff --git a/wonnx/templates/endomorphism/erf.wgsl b/wonnx/templates/endomorphism/erf.wgsl new file mode 100644 index 00000000..4d321e71 --- /dev/null +++ b/wonnx/templates/endomorphism/erf.wgsl @@ -0,0 +1,16 @@ +{%- include "structs.wgsl" -%} +@group(0) @binding(0) +var input_0: ArrayVector; + +const pi: f32 = 3.1415; + +@group(0) @binding(1) +var output_0: ArrayVector; + +@compute @workgroup_size({{ workgroup_size_x }}) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let gidx = global_id.x; + var intermediate = 2.0/sqrt(pi)*(input_0.data[gidx]+ pow(input_0.data[gidx],vec4(3.0,3.0,3.0,3.0))*0.08943 ); + intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0)); + output_0.data[gidx] = tanh(intermediate); +} \ No newline at end of file diff --git a/wonnx/tests/arithmetic.rs b/wonnx/tests/arithmetic.rs index cdef5c0a..057d8852 100644 --- a/wonnx/tests/arithmetic.rs +++ b/wonnx/tests/arithmetic.rs @@ -1,4 +1,5 @@ use approx::assert_abs_diff_eq; +use log::error; use std::{collections::HashMap, convert::TryInto}; use wonnx::{ onnx::TensorProto_DataType, @@ -9,7 +10,36 @@ use wonnx::{ }; mod common; +#[test] +fn test_erf() { + use ndarray_rand::rand_distr::Uniform; + use ndarray_rand::RandomExt; + let n: usize = 16; + let mut input_data = HashMap::new(); + let data = ndarray::Array1::::random(n, Uniform::new(-100f32, 100f32)); + + let shape = vec![n as i64]; + input_data.insert("X".to_string(), data.as_slice().unwrap().into()); + // Model: X -> Cos -> Y + let model = model(graph( + vec![tensor("X", &shape)], + vec![tensor("Y", &shape)], + vec![], + vec![], + vec![node(vec!["X"], vec!["Y"], "erf", "Erf", vec![])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + if let OutputTensor::F32(vec) = &result["Y"] { + for e in vec { + assert_ne!(*e, f32::NAN); + } + } +} #[test] fn test_cos() { let n: usize = 16; From f541bd175409d98145264f48462450824bed9062 Mon Sep 17 00:00:00 2001 From: amine dirhoussi Date: Thu, 21 Dec 2023 10:56:57 +0100 Subject: [PATCH 2/3] updated readme for erf support --- Cargo.lock | 12 ------------ README.md | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0c958084..18dbe24c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2071,17 +2071,6 @@ dependencies = [ "rawpointer", ] -[[package]] -name = "ndarray-rand" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65608f937acc725f5b164dcf40f4f0bc5d67dc268ab8a649d3002606718c4588" -dependencies = [ - "ndarray", - "rand", - "rand_distr", -] - [[package]] name = "nom" version = "7.1.3" @@ -4289,7 +4278,6 @@ dependencies = [ "image", "log", "ndarray", - "ndarray-rand", "num", "parking_lot 0.11.2", "pollster", diff --git a/README.md b/README.md index 5c9b430d..6ed44329 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ fn test_matmul_square_matrix() { |Einsum|12| |Elu|6, 1|✅|✅| |Equal|13, 11, 7, 1|✅| -|Erf|13, 9||✅| +|Erf|13, 9|✅|✅| |Exp|13, 6, 1|✅|✅| |Expand|13, 8| |EyeLike|9| From 19906f9a137217782a67482207f9745b825e1209 Mon Sep 17 00:00:00 2001 From: amine dirhoussi Date: Thu, 4 Jan 2024 21:46:19 +0100 Subject: [PATCH 3/3] merged erf implementation in activation_vec --- wonnx/Cargo.toml | 1 + wonnx/src/compiler.rs | 21 +------------------- wonnx/templates/endomorphism/activation.wgsl | 2 ++ wonnx/templates/endomorphism/erf.wgsl | 16 --------------- wonnx/templates/snippets/activation_vec.wgsl | 5 +++++ wonnx/tests/arithmetic.rs | 6 +++--- 6 files changed, 12 insertions(+), 39 deletions(-) delete mode 100644 wonnx/templates/endomorphism/erf.wgsl diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml index 4b941f22..fd250d32 100644 --- a/wonnx/Cargo.toml +++ b/wonnx/Cargo.toml @@ -45,3 +45,4 @@ ndarray = "0.15.4" approx = "0.5.1" pollster = "0.3.0" env_logger = "0.10.0" +ndarray-rand = "0.14.0" diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 41001e7c..009f4a42 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -42,11 +42,6 @@ fn get_templates() -> &'static Tera { include_str!("../templates/endomorphism/softmax.wgsl"), ) .unwrap(); - tera.add_raw_template( - "endomorphism/erf.wgsl", - include_str!("../templates/endomorphism/erf.wgsl"), - ) - .unwrap(); tera.add_raw_template( "endomorphism/map.wgsl", include_str!("../templates/endomorphism/map.wgsl"), @@ -306,20 +301,6 @@ pub fn compile( } } - "Erf" => { - let (x_threads, workgroup_size_x) = workgroup_size( - ceil(output_lengths[0], 4), - MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, - MAX_WORKGROUP_SIZE_X, - )?; - context.insert("workgroup_size_x", &workgroup_size_x); - NodeTemplate { - scalar_type: agreed_type(input_shapes, output_shapes)?, - template: "endomorphism/erf.wgsl", - threads: (x_threads, 1, 1), - } - } - op @ ("ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" | "ReduceProd" | "ReduceL1" | "ReduceL2" | "ReduceLogSum" | "ReduceLogSumExp" | "ReduceSumSquare") => { @@ -762,7 +743,7 @@ pub fn compile( } } op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu" - | "LeakyRelu" | "HardSigmoid") => { + | "LeakyRelu" | "HardSigmoid" | "Erf") => { let alpha = match op { "LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?, "HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?, diff --git a/wonnx/templates/endomorphism/activation.wgsl b/wonnx/templates/endomorphism/activation.wgsl index 0b9cf649..6b76392d 100644 --- a/wonnx/templates/endomorphism/activation.wgsl +++ b/wonnx/templates/endomorphism/activation.wgsl @@ -6,6 +6,8 @@ var input_0: ArrayVector; @group(0) @binding(1) var output_0: ArrayVector; +const pi: f32 = 3.1415; + @compute @workgroup_size({{ workgroup_size_x }}) fn main(@builtin(global_invocation_id) global_id: vec3) { let gidx = global_id.x; diff --git a/wonnx/templates/endomorphism/erf.wgsl b/wonnx/templates/endomorphism/erf.wgsl deleted file mode 100644 index 4d321e71..00000000 --- a/wonnx/templates/endomorphism/erf.wgsl +++ /dev/null @@ -1,16 +0,0 @@ -{%- include "structs.wgsl" -%} -@group(0) @binding(0) -var input_0: ArrayVector; - -const pi: f32 = 3.1415; - -@group(0) @binding(1) -var output_0: ArrayVector; - -@compute @workgroup_size({{ workgroup_size_x }}) -fn main(@builtin(global_invocation_id) global_id: vec3) { - let gidx = global_id.x; - var intermediate = 2.0/sqrt(pi)*(input_0.data[gidx]+ pow(input_0.data[gidx],vec4(3.0,3.0,3.0,3.0))*0.08943 ); - intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0)); - output_0.data[gidx] = tanh(intermediate); -} \ No newline at end of file diff --git a/wonnx/templates/snippets/activation_vec.wgsl b/wonnx/templates/snippets/activation_vec.wgsl index c1e183dc..9327f236 100644 --- a/wonnx/templates/snippets/activation_vec.wgsl +++ b/wonnx/templates/snippets/activation_vec.wgsl @@ -46,6 +46,11 @@ {{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar())) + min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar())); +{%- elif activation_type == "Erf" -%} + var intermediate = 2.0/sqrt(pi)*({{ activation_input }}+ pow({{activation_input}},vec4(3.0,3.0,3.0,3.0))*0.08943 ); + intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0)); + {{ activation_output }} = tanh(intermediate); + {%- elif activation_type == "HardSigmoid" -%} {{ activation_output }} = max( Vec4(Scalar(), Scalar(), Scalar(), Scalar()), diff --git a/wonnx/tests/arithmetic.rs b/wonnx/tests/arithmetic.rs index 057d8852..655a380f 100644 --- a/wonnx/tests/arithmetic.rs +++ b/wonnx/tests/arithmetic.rs @@ -1,5 +1,4 @@ use approx::assert_abs_diff_eq; -use log::error; use std::{collections::HashMap, convert::TryInto}; use wonnx::{ onnx::TensorProto_DataType, @@ -9,11 +8,12 @@ use wonnx::{ }, }; +use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::RandomExt; + mod common; #[test] fn test_erf() { - use ndarray_rand::rand_distr::Uniform; - use ndarray_rand::RandomExt; let n: usize = 16; let mut input_data = HashMap::new(); let data = ndarray::Array1::::random(n, Uniform::new(-100f32, 100f32));