Merge pull request #185 from ariaghora/feature/hardsigmoid
Integration of HardSigmoid Operation
pixelspark authored Jan 1, 2024
2 parents 0c56190 + 95e7b38 commit 2b72b51
Showing 6 changed files with 121 additions and 15 deletions.
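For context, ONNX's HardSigmoid is a piecewise-linear approximation of the logistic sigmoid: y = max(0, min(1, alpha * x + beta)), with spec defaults alpha = 0.2 and beta = 0.5 (the same defaults the compiler change below wires in). A minimal reference sketch in Rust, illustrative only and not part of this diff:

/// HardSigmoid per the ONNX operator spec:
/// y = max(0, min(1, alpha * x + beta)), defaults alpha = 0.2, beta = 0.5.
fn hard_sigmoid(x: f32, alpha: f32, beta: f32) -> f32 {
    (alpha * x + beta).clamp(0.0, 1.0)
}

fn main() {
    // Saturates to 0 below x = -beta/alpha (-2.5 at the defaults) and to 1
    // above x = (1 - beta)/alpha (2.5 at the defaults); linear in between.
    assert_eq!(hard_sigmoid(-3.0, 0.2, 0.5), 0.0);
    assert_eq!(hard_sigmoid(3.0, 0.2, 0.5), 1.0);
    assert_eq!(hard_sigmoid(0.0, 0.2, 0.5), 0.5);
}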
2 changes: 1 addition & 1 deletion README.md
@@ -265,7 +265,7 @@ fn test_matmul_square_matrix() {
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalMaxPool">GlobalMaxPool</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalMaxPool-1">1</a>|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater">Greater</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-9">9</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-7">7</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-1">1</a>||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#GridSample">GridSample</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GridSample-16">16</a>|
-|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid">HardSigmoid</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-1">1</a>|
+|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid">HardSigmoid</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-1">1</a>|||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax">Hardmax</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Hardmax-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Hardmax-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Hardmax-1">1</a>|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity">Identity</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16">16</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14">14</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1">1</a>|||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#If">If</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-16">16</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-1">1</a>|
22 changes: 12 additions & 10 deletions wonnx-py/tests/test_onnx_backend.py
@@ -124,7 +124,6 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)


-
backend_test.include(f"test_constant_cpu")
backend_test.include(f"test_conv_[a-z,_]*")
backend_test.include(f"test_Conv2d[a-z,_]*")
@@ -147,6 +146,9 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
backend_test.include(f"test_size_[a-z,_]*")
backend_test.include(f"test_celu_[a-z,_]*")

+# Disabled until CastLike is implemented
+# backend_test.include(f"test_hardsigmoid_[a-z,_]*")
+
# For these we only test the default version, as we don't support the bool type
backend_test.include(f"test_prelu_broadcast_cpu$")
backend_test.include(f"test_elu_cpu$")
@@ -162,15 +164,15 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
# Disable tests for ReduceSum because ReduceSum accepts the 'axes' list as input instead of as an attribute, and the test
# case sets the 'axes' input dynamically, which we don't support (yet?).
# backend_test.include(f"test_reduce_sum_[a-z,_]*")
-#backend_test.include(f"test_reduce_mean_[a-z,_]*")
-#backend_test.include(f"test_reduce_l1_[a-z,_]*")
-#backend_test.include(f"test_reduce_l2_[a-z,_]*")
-#backend_test.include(f"test_reduce_min_[a-z,_]*")
-#backend_test.include(f"test_reduce_prod_[a-z,_]*")
-#backend_test.include(f"test_reduce_sum_square_[a-z,_]*")
-#backend_test.include(f"test_reduce_max_[a-z,_]*")
-#backend_test.include(f"test_reduce_log_sum_[a-z,_]*")
-#backend_test.include(f"test_reduce_log_sum_exp_[a-z,_]*")
+# backend_test.include(f"test_reduce_mean_[a-z,_]*")
+# backend_test.include(f"test_reduce_l1_[a-z,_]*")
+# backend_test.include(f"test_reduce_l2_[a-z,_]*")
+# backend_test.include(f"test_reduce_min_[a-z,_]*")
+# backend_test.include(f"test_reduce_prod_[a-z,_]*")
+# backend_test.include(f"test_reduce_sum_square_[a-z,_]*")
+# backend_test.include(f"test_reduce_max_[a-z,_]*")
+# backend_test.include(f"test_reduce_log_sum_[a-z,_]*")
+# backend_test.include(f"test_reduce_log_sum_exp_[a-z,_]*")

# Takes dynamic input, we don't support that yet
# backend_test.include(f"test_constantofshape_[a-z,_]*")
16 changes: 12 additions & 4 deletions wonnx/src/compiler.rs
@@ -743,13 +743,21 @@ pub fn compile(
        }
    }
    op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu"
-   | "LeakyRelu") => {
-       let alpha = if op == "LeakyRelu" {
-           node.get_attribute_value("alpha", Some(0.01))?
+   | "LeakyRelu" | "HardSigmoid") => {
+       let alpha = match op {
+           "LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?,
+           "HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?,
+           _ => node.get_attribute_value("alpha", Some(1.0))?,
+       };
+
+       let beta = if op == "HardSigmoid" {
+           node.get_attribute_value("beta", Some(0.5))?
        } else {
-           node.get_attribute_value("alpha", Some(1.0))?
+           node.get_attribute_value("beta", Some(1.0))?
        };

        context.insert("alpha", &alpha);
+       context.insert("beta", &beta);
+
        if op == "Clip" {
            let min: Vec<f32> =
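The `context.insert` calls feed the template context that renders the WGSL snippets below, substituting the resolved values into the `{{ alpha }}` and `{{ beta }}` placeholders. A standalone, hedged illustration of that substitution using the tera crate (wonnx's actual compiler assembles templates from files; `Tera::one_off` and the snippet string here are just for demonstration):

// Standalone sketch of how alpha/beta reach the shader source.
fn main() -> Result<(), tera::Error> {
    let mut context = tera::Context::new();
    context.insert("alpha", &0.2f64);
    context.insert("beta", &0.5f64);

    let snippet = "y = max(f32(0), min(f32(1), f32({{ alpha }}) * x + f32({{ beta }})));";
    let rendered = tera::Tera::one_off(snippet, &context, false)?;
    println!("{rendered}"); // y = max(f32(0), min(f32(1), f32(0.2) * x + f32(0.5)));
    Ok(())
}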
9 changes: 9 additions & 0 deletions wonnx/templates/snippets/activation_scalar.wgsl
@@ -38,6 +38,15 @@
    {{ scalar_type }}({{ alpha }}) * (exp(input_vec) - {{ scalar_type }}(1))
);

+{%- elif activation_type == "HardSigmoid" -%}
+{{ activation_output }} = max(
+    {{ scalar_type }}(0),
+    min(
+        {{ scalar_type }}(1),
+        {{ scalar_type }}({{ alpha }}) * {{ activation_input }} + {{ scalar_type }}({{ beta }})
+    )
+);
+
{%- elif activation_output != activation_input -%}
{{ activation_output }} = {{ activation_input }};

9 changes: 9 additions & 0 deletions wonnx/templates/snippets/activation_vec.wgsl
@@ -46,6 +46,15 @@
{{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()))
    + min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()));

+{%- elif activation_type == "HardSigmoid" -%}
+{{ activation_output }} = max(
+    Vec4(Scalar(), Scalar(), Scalar(), Scalar()),
+    min(
+        Vec4({{ scalar_type }}(1), {{ scalar_type }}(1), {{ scalar_type }}(1), {{ scalar_type }}(1)),
+        {{ scalar_type }}({{ alpha }}) * {{ activation_input }} + {{ scalar_type }}({{ beta }})
+    )
+);
+
{%- elif activation_output != activation_input -%}
{{ activation_output }} = {{ activation_input }};

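Both snippets implement the same clamp; only the element width differs. A minimal Rust analogue of the two template paths, my sketch assuming `scalar_type` renders as `f32` (function names here are illustrative, not wonnx API):

fn hard_sigmoid_scalar(x: f32, alpha: f32, beta: f32) -> f32 {
    // The scalar branch: one clamp per element.
    f32::max(0.0, f32::min(1.0, alpha * x + beta))
}

fn hard_sigmoid_vec4(x: [f32; 4], alpha: f32, beta: f32) -> [f32; 4] {
    // The vec branch applies the same clamp to each of the four lanes.
    // `Vec4(Scalar(), ...)` in the template is the all-zeros vector; the
    // all-ones vector is spelled out lane by lane, matching the WGSL above.
    x.map(|lane| f32::max(0.0, f32::min(1.0, alpha * lane + beta)))
}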
78 changes: 78 additions & 0 deletions wonnx/tests/hardsigmoid.rs
@@ -0,0 +1,78 @@
use std::{collections::HashMap, convert::TryInto};
use wonnx::utils::{attribute, graph, model, node, tensor};
mod common;

/// Test HardSigmoid node with default alpha and beta
/// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-68
#[test]
fn test_hardsigmoid_default() {
    let input_data = [-2.0, -1.0, 1.0, 2.0];
    let shape = vec![2, 2];

    let (default_alpha, default_beta) = (0.2, 0.5);
    let expected_output_data: Vec<f32> = input_data
        .iter()
        .map(|x| x * default_alpha + default_beta)
        .collect();

    let mut model_input = HashMap::new();
    model_input.insert("X".to_string(), input_data.as_slice().into());

    let node = node(vec!["X"], vec!["Y"], "hard_sigmoid", "HardSigmoid", vec![]);

    let model = model(graph(
        vec![tensor("X", &shape)],
        vec![tensor("Y", &shape)],
        vec![],
        vec![],
        vec![node],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

    let output = pollster::block_on(session.run(&model_input)).unwrap();
    let output_data: &[f32] = (&output["Y"]).try_into().unwrap();

    common::assert_eq_vector(output_data, expected_output_data.as_slice());
}

/// Test HardSigmoid node with predefined alpha and beta
/// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-68
#[test]
fn test_hardsigmoid() {
    let input_data: Vec<f32> = vec![-1.0, 0.0, 1.0];
    let shape = vec![1, 3];

    let mut model_input = HashMap::new();
    model_input.insert("X".to_string(), input_data.as_slice().into());

    let alpha = attribute("alpha", 0.5);
    let beta = attribute("beta", 0.6);

    let node = node(
        vec!["X"],
        vec!["Y"],
        "hard_sigmoid",
        "HardSigmoid",
        vec![alpha, beta],
    );

    let model = model(graph(
        vec![tensor("X", &shape)],
        vec![tensor("Y", &shape)],
        vec![],
        vec![],
        vec![node],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

    let output = pollster::block_on(session.run(&model_input)).unwrap();
    println!("{:?}", output);

    let expected_output = &[0.1, 0.6, 1.0];
    let output_data: &[f32] = (&output["Y"]).try_into().unwrap();
    common::assert_eq_vector(output_data, expected_output);
}
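The two expectations line up with the formula. With the defaults (alpha 0.2, beta 0.5), every input in the first test already maps inside [0, 1], so the plain linear map `x * alpha + beta` suffices as the expected output; in the second test, x = 1.0 gives 0.5 * 1.0 + 0.6 = 1.1, which the clamp cuts to 1.0. A quick standalone check of that arithmetic (my sketch, not part of the test file):

fn main() {
    let clamp01 = |v: f32| v.clamp(0.0, 1.0);

    // First test: defaults alpha = 0.2, beta = 0.5 on [-2, -1, 1, 2].
    // All raw values land in [0, 1], so clamping is a no-op there.
    let raw: Vec<f32> = [-2.0f32, -1.0, 1.0, 2.0]
        .iter()
        .map(|x| x * 0.2 + 0.5)
        .collect();
    assert!(raw.iter().all(|&v| (0.0..=1.0).contains(&v))); // ~[0.1, 0.3, 0.7, 0.9]

    // Second test: alpha = 0.5, beta = 0.6 on [-1, 0, 1].
    // x = 1.0 overshoots to 1.1 before the clamp brings it back to 1.0.
    let ys: Vec<f32> = [-1.0f32, 0.0, 1.0]
        .iter()
        .map(|&x| clamp01(0.5 * x + 0.6))
        .collect();
    for (y, e) in ys.iter().zip(&[0.1f32, 0.6, 1.0]) {
        assert!((y - e).abs() < 1e-6);
    }
}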
