diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml index 4b941f22..fd250d32 100644 --- a/wonnx/Cargo.toml +++ b/wonnx/Cargo.toml @@ -45,3 +45,4 @@ ndarray = "0.15.4" approx = "0.5.1" pollster = "0.3.0" env_logger = "0.10.0" +ndarray-rand = "0.14.0" diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 41001e7c..009f4a42 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -42,11 +42,6 @@ fn get_templates() -> &'static Tera { include_str!("../templates/endomorphism/softmax.wgsl"), ) .unwrap(); - tera.add_raw_template( - "endomorphism/erf.wgsl", - include_str!("../templates/endomorphism/erf.wgsl"), - ) - .unwrap(); tera.add_raw_template( "endomorphism/map.wgsl", include_str!("../templates/endomorphism/map.wgsl"), @@ -306,20 +301,6 @@ pub fn compile( } } - "Erf" => { - let (x_threads, workgroup_size_x) = workgroup_size( - ceil(output_lengths[0], 4), - MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, - MAX_WORKGROUP_SIZE_X, - )?; - context.insert("workgroup_size_x", &workgroup_size_x); - NodeTemplate { - scalar_type: agreed_type(input_shapes, output_shapes)?, - template: "endomorphism/erf.wgsl", - threads: (x_threads, 1, 1), - } - } - op @ ("ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" | "ReduceProd" | "ReduceL1" | "ReduceL2" | "ReduceLogSum" | "ReduceLogSumExp" | "ReduceSumSquare") => { @@ -762,7 +743,7 @@ pub fn compile( } } op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu" - | "LeakyRelu" | "HardSigmoid") => { + | "LeakyRelu" | "HardSigmoid" | "Erf") => { let alpha = match op { "LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?, "HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?, diff --git a/wonnx/templates/endomorphism/activation.wgsl b/wonnx/templates/endomorphism/activation.wgsl index 0b9cf649..6b76392d 100644 --- a/wonnx/templates/endomorphism/activation.wgsl +++ b/wonnx/templates/endomorphism/activation.wgsl @@ -6,6 +6,8 @@ var input_0: ArrayVector; @group(0) @binding(1) var output_0: ArrayVector; +const pi: f32 = 3.1415; + @compute @workgroup_size({{ workgroup_size_x }}) fn main(@builtin(global_invocation_id) global_id: vec3) { let gidx = global_id.x; diff --git a/wonnx/templates/endomorphism/erf.wgsl b/wonnx/templates/endomorphism/erf.wgsl deleted file mode 100644 index 4d321e71..00000000 --- a/wonnx/templates/endomorphism/erf.wgsl +++ /dev/null @@ -1,16 +0,0 @@ -{%- include "structs.wgsl" -%} -@group(0) @binding(0) -var input_0: ArrayVector; - -const pi: f32 = 3.1415; - -@group(0) @binding(1) -var output_0: ArrayVector; - -@compute @workgroup_size({{ workgroup_size_x }}) -fn main(@builtin(global_invocation_id) global_id: vec3) { - let gidx = global_id.x; - var intermediate = 2.0/sqrt(pi)*(input_0.data[gidx]+ pow(input_0.data[gidx],vec4(3.0,3.0,3.0,3.0))*0.08943 ); - intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0)); - output_0.data[gidx] = tanh(intermediate); -} \ No newline at end of file diff --git a/wonnx/templates/snippets/activation_vec.wgsl b/wonnx/templates/snippets/activation_vec.wgsl index c1e183dc..9327f236 100644 --- a/wonnx/templates/snippets/activation_vec.wgsl +++ b/wonnx/templates/snippets/activation_vec.wgsl @@ -46,6 +46,11 @@ {{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar())) + min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar())); +{%- elif activation_type == "Erf" -%} + var intermediate = 2.0/sqrt(pi)*({{ activation_input }}+ pow({{activation_input}},vec4(3.0,3.0,3.0,3.0))*0.08943 ); + intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0)); + {{ activation_output }} = tanh(intermediate); + {%- elif activation_type == "HardSigmoid" -%} {{ activation_output }} = max( Vec4(Scalar(), Scalar(), Scalar(), Scalar()), diff --git a/wonnx/tests/arithmetic.rs b/wonnx/tests/arithmetic.rs index 057d8852..655a380f 100644 --- a/wonnx/tests/arithmetic.rs +++ b/wonnx/tests/arithmetic.rs @@ -1,5 +1,4 @@ use approx::assert_abs_diff_eq; -use log::error; use std::{collections::HashMap, convert::TryInto}; use wonnx::{ onnx::TensorProto_DataType, @@ -9,11 +8,12 @@ use wonnx::{ }, }; +use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::RandomExt; + mod common; #[test] fn test_erf() { - use ndarray_rand::rand_distr::Uniform; - use ndarray_rand::RandomExt; let n: usize = 16; let mut input_data = HashMap::new(); let data = ndarray::Array1::::random(n, Uniform::new(-100f32, 100f32));