diff --git a/README.md b/README.md
index c132704a..dcdfe6d3 100644
--- a/README.md
+++ b/README.md
@@ -249,7 +249,7 @@ fn test_matmul_square_matrix() {
|Einsum|12|
|Elu|6, 1|✅|✅|
|Equal|13, 11, 7, 1|✅|
-|Erf|13, 9||✅|
+|Erf|13, 9|✅|✅|
|Exp|13, 6, 1|✅|✅|
|Expand|13, 8|
|EyeLike|9|
diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml
index 43a0dbf9..b24f5a53 100644
--- a/wonnx/Cargo.toml
+++ b/wonnx/Cargo.toml
@@ -45,3 +45,4 @@ ndarray = "0.15.4"
approx = "0.5.1"
pollster = "0.3.0"
env_logger = "0.10.0"
+ndarray-rand = "0.14.0"
diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs
index 7d2dea23..009f4a42 100644
--- a/wonnx/src/compiler.rs
+++ b/wonnx/src/compiler.rs
@@ -743,7 +743,7 @@ pub fn compile(
}
}
op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu"
- | "LeakyRelu" | "HardSigmoid") => {
+ | "LeakyRelu" | "HardSigmoid" | "Erf") => {
let alpha = match op {
"LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?,
"HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?,
diff --git a/wonnx/templates/endomorphism/activation.wgsl b/wonnx/templates/endomorphism/activation.wgsl
index 0b9cf649..6b76392d 100644
--- a/wonnx/templates/endomorphism/activation.wgsl
+++ b/wonnx/templates/endomorphism/activation.wgsl
@@ -6,6 +6,8 @@ var input_0: ArrayVector;
@group(0) @binding(1)
var output_0: ArrayVector;
+const pi: f32 = 3.1415;
+
@compute @workgroup_size({{ workgroup_size_x }})
fn main(@builtin(global_invocation_id) global_id: vec3) {
let gidx = global_id.x;
diff --git a/wonnx/templates/snippets/activation_vec.wgsl b/wonnx/templates/snippets/activation_vec.wgsl
index c1e183dc..9327f236 100644
--- a/wonnx/templates/snippets/activation_vec.wgsl
+++ b/wonnx/templates/snippets/activation_vec.wgsl
@@ -46,6 +46,11 @@
{{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()))
+ min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()));
+{%- elif activation_type == "Erf" -%}
+ var intermediate = 2.0/sqrt(pi)*({{ activation_input }}+ pow({{activation_input}},vec4(3.0,3.0,3.0,3.0))*0.08943 );
+ intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0));
+ {{ activation_output }} = tanh(intermediate);
+
{%- elif activation_type == "HardSigmoid" -%}
{{ activation_output }} = max(
Vec4(Scalar(), Scalar(), Scalar(), Scalar()),
diff --git a/wonnx/tests/arithmetic.rs b/wonnx/tests/arithmetic.rs
index cdef5c0a..655a380f 100644
--- a/wonnx/tests/arithmetic.rs
+++ b/wonnx/tests/arithmetic.rs
@@ -8,8 +8,38 @@ use wonnx::{
},
};
+use ndarray_rand::rand_distr::Uniform;
+use ndarray_rand::RandomExt;
+
mod common;
+#[test]
+fn test_erf() {
+ let n: usize = 16;
+ let mut input_data = HashMap::new();
+ let data = ndarray::Array1::::random(n, Uniform::new(-100f32, 100f32));
+
+ let shape = vec![n as i64];
+ input_data.insert("X".to_string(), data.as_slice().unwrap().into());
+
+ // Model: X -> Cos -> Y
+ let model = model(graph(
+ vec![tensor("X", &shape)],
+ vec![tensor("Y", &shape)],
+ vec![],
+ vec![],
+ vec![node(vec!["X"], vec!["Y"], "erf", "Erf", vec![])],
+ ));
+
+ let session =
+ pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+ let result = pollster::block_on(session.run(&input_data)).unwrap();
+ if let OutputTensor::F32(vec) = &result["Y"] {
+ for e in vec {
+ assert_ne!(*e, f32::NAN);
+ }
+ }
+}
#[test]
fn test_cos() {
let n: usize = 16;