diff --git a/README.md b/README.md index c132704a..dcdfe6d3 100644 --- a/README.md +++ b/README.md @@ -249,7 +249,7 @@ fn test_matmul_square_matrix() { |Einsum|12| |Elu|6, 1|✅|✅| |Equal|13, 11, 7, 1|✅| -|Erf|13, 9||✅| +|Erf|13, 9|✅|✅| |Exp|13, 6, 1|✅|✅| |Expand|13, 8| |EyeLike|9| diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml index 43a0dbf9..b24f5a53 100644 --- a/wonnx/Cargo.toml +++ b/wonnx/Cargo.toml @@ -45,3 +45,4 @@ ndarray = "0.15.4" approx = "0.5.1" pollster = "0.3.0" env_logger = "0.10.0" +ndarray-rand = "0.14.0" diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 7d2dea23..009f4a42 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -743,7 +743,7 @@ pub fn compile( } } op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu" - | "LeakyRelu" | "HardSigmoid") => { + | "LeakyRelu" | "HardSigmoid" | "Erf") => { let alpha = match op { "LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?, "HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?, diff --git a/wonnx/templates/endomorphism/activation.wgsl b/wonnx/templates/endomorphism/activation.wgsl index 0b9cf649..6b76392d 100644 --- a/wonnx/templates/endomorphism/activation.wgsl +++ b/wonnx/templates/endomorphism/activation.wgsl @@ -6,6 +6,8 @@ var input_0: ArrayVector; @group(0) @binding(1) var output_0: ArrayVector; +const pi: f32 = 3.1415; + @compute @workgroup_size({{ workgroup_size_x }}) fn main(@builtin(global_invocation_id) global_id: vec3) { let gidx = global_id.x; diff --git a/wonnx/templates/snippets/activation_vec.wgsl b/wonnx/templates/snippets/activation_vec.wgsl index c1e183dc..9327f236 100644 --- a/wonnx/templates/snippets/activation_vec.wgsl +++ b/wonnx/templates/snippets/activation_vec.wgsl @@ -46,6 +46,11 @@ {{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar())) + min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar())); +{%- elif activation_type == "Erf" -%} + var intermediate = 2.0/sqrt(pi)*({{ activation_input }}+ pow({{activation_input}},vec4(3.0,3.0,3.0,3.0))*0.08943 ); + intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0)); + {{ activation_output }} = tanh(intermediate); + {%- elif activation_type == "HardSigmoid" -%} {{ activation_output }} = max( Vec4(Scalar(), Scalar(), Scalar(), Scalar()), diff --git a/wonnx/tests/arithmetic.rs b/wonnx/tests/arithmetic.rs index cdef5c0a..655a380f 100644 --- a/wonnx/tests/arithmetic.rs +++ b/wonnx/tests/arithmetic.rs @@ -8,8 +8,38 @@ use wonnx::{ }, }; +use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::RandomExt; + mod common; +#[test] +fn test_erf() { + let n: usize = 16; + let mut input_data = HashMap::new(); + let data = ndarray::Array1::::random(n, Uniform::new(-100f32, 100f32)); + + let shape = vec![n as i64]; + input_data.insert("X".to_string(), data.as_slice().unwrap().into()); + + // Model: X -> Cos -> Y + let model = model(graph( + vec![tensor("X", &shape)], + vec![tensor("Y", &shape)], + vec![], + vec![], + vec![node(vec!["X"], vec!["Y"], "erf", "Erf", vec![])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + if let OutputTensor::F32(vec) = &result["Y"] { + for e in vec { + assert_ne!(*e, f32::NAN); + } + } +} #[test] fn test_cos() { let n: usize = 16;