From 405243e4a9af6ce15add19bba518b67d70a5333b Mon Sep 17 00:00:00 2001 From: CameronTofer Date: Sun, 26 Mar 2023 23:33:05 +0200 Subject: [PATCH 1/6] implement LRN (cherry picked from commit 09484327a0a90f42de8d09930f16b7d44c7bb58e) --- wonnx/src/compiler.rs | 37 ++++++++++++++++ wonnx/templates/matrix/lrn.wgsl | 23 ++++++++++ wonnx/tests/localresponsenormalization.rs | 52 +++++++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 wonnx/templates/matrix/lrn.wgsl create mode 100644 wonnx/tests/localresponsenormalization.rs diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index fe9a3fcf..f4cbf476 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -83,6 +83,11 @@ lazy_static! { include_str!("../templates/matrix/transpose.wgsl"), ) .unwrap(); + tera.add_raw_template( + "matrix/lrn.wgsl", + include_str!("../templates/matrix/lrn.wgsl"), + ) + .unwrap(); tera.add_raw_template( "pool/aggregate.wgsl", include_str!("../templates/pool/aggregate.wgsl"), @@ -1321,6 +1326,38 @@ pub fn compile( threads: (ceil(output_lengths[0], 256) as _, 1, 1), } } + "LocalResponseNormalization" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#lrn + let alpha = get_attribute("alpha", Some(0.0001), node)?; + let beta = get_attribute("beta", Some(0.75), node)?; + let bias = get_attribute("bias", Some(1.0), node)?; + let size = get_attribute("size", Some(1), node)?; + + context.insert("alpha", &alpha); + context.insert("beta", &beta); + context.insert("bias", &bias); + context.insert("size", &size); + + let left_size = f64::floor((size - 1) as f64 / 2.0) as u32; + let right_size = f64::ceil((size - 1) as f64 / 2.0) as u32; + + context.insert("left_size", &left_size); + context.insert("right_size", &right_size); + + let (x_threads, workgroup_size_x) = workgroup_size( + output_lengths[0], + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + context.insert("workgroup_size_x", &workgroup_size_x); + context.insert("i_chunks", &input_chunks); + + NodeTemplate { + scalar_type: agreed_type(input_shapes, output_shapes)?, + template: "matrix/lrn.wgsl", + threads: (x_threads, 1, 1), + } + } op => return Err(CompileError::UnimplementedOp(op.to_string())), }; diff --git a/wonnx/templates/matrix/lrn.wgsl b/wonnx/templates/matrix/lrn.wgsl new file mode 100644 index 00000000..d18b3518 --- /dev/null +++ b/wonnx/templates/matrix/lrn.wgsl @@ -0,0 +1,23 @@ +{%- include "structs.wgsl" -%} + +@group(0) @binding(0) +var input_0: Array; + +@group(0) @binding(1) +var output_0: Array; + +@compute @workgroup_size({{ workgroup_size_x }}) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let c = global_id.x; + //let chunk_start = {{ i_chunks[0][1] }}u * c; + let start = (c / {{ i_shape[0][1] }}u) * {{ i_shape[0][1] }}u; + let end = start + {{ i_shape[0][1] - 1 }}u; + + var square_sum: Scalar = Scalar(); + for (var i = max(start, c - {{left_size}}u); i <= min(end, c + {{right_size}}u); i++) { + let I = input_0.data[i]; + square_sum += I * I; + } + + output_0.data[c] = input_0.data[ c ] / pow({{bias}} + ({{alpha}} / {{size}}.0) * square_sum,{{beta}}); +} diff --git a/wonnx/tests/localresponsenormalization.rs b/wonnx/tests/localresponsenormalization.rs new file mode 100644 index 00000000..ddb4a918 --- /dev/null +++ b/wonnx/tests/localresponsenormalization.rs @@ -0,0 +1,52 @@ +use std::{collections::HashMap, convert::TryInto}; +use wonnx::utils::{attribute, graph, model, node, tensor}; +mod common; + +#[test] +fn local_response_normalization() { + let mut input_data = HashMap::new(); + + let batches = 1; + let width_height: usize = 3; + let channels: usize = 4; + let data: Vec = [ 1.,1.,2.,4., 2.,2.,1.,2., 3.,1.,2.,1., 4.,2.,3.,5., 3.,3.,2.,2., 6.,2.,3.,1., 7.,3.,4.,2., 8.,4.,3.,2., 9.,3.,4.,4.].to_vec(); + + let shape = vec![ + batches as i64, + channels as i64, + width_height as i64, + width_height as i64, + ]; + input_data.insert("X".to_string(), data.as_slice().into()); + + let bn_model = model(graph( + vec![tensor("X", &shape)], // input + vec![tensor("Y", &shape)], // output + vec![], // infos + vec![], // intializers + + // nodes + vec![node( + vec!["X"], + vec!["Y"], + "lrn", + "LocalResponseNormalization", + vec![ attribute("alpha", 1.0), + attribute("beta", 1.0), + attribute("bias", 0.0), + attribute("size", 2)], + )], + )); + + // LOGIC + let session = + pollster::block_on(wonnx::Session::from_model(bn_model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + let out_y = &result["Y"]; + + common::assert_eq_vector( + out_y.try_into().unwrap(), + &[1.0, 0.4, 0.2, 0.5, 0.5, 0.8, 0.4, 1.0, 0.6, 0.4, 0.8, 2.0, 0.4, 0.30769232, 0.1764706, 0.39999998, 0.33333334, 0.4615385, 0.5, 1.0, 0.3, 0.30769232, 0.6, 2.0, 0.2413793, 0.24, 0.4, 1.0, 0.2, 0.32, 0.4615385, 1.0, 0.2, 0.24, 0.25, 0.5], + ); +} \ No newline at end of file From cae7006270c569c8425b180433867bbc311ac134 Mon Sep 17 00:00:00 2001 From: CameronTofer Date: Sun, 26 Mar 2023 23:33:05 +0200 Subject: [PATCH 2/6] got name wrong. (cherry picked from commit 66f92d797bd09d531e7444a69261e5bd19509ad8) --- wonnx/src/compiler.rs | 2 +- wonnx/tests/localresponsenormalization.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index f4cbf476..108b0f30 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -1326,7 +1326,7 @@ pub fn compile( threads: (ceil(output_lengths[0], 256) as _, 1, 1), } } - "LocalResponseNormalization" => { + "LRN" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#lrn let alpha = get_attribute("alpha", Some(0.0001), node)?; let beta = get_attribute("beta", Some(0.75), node)?; diff --git a/wonnx/tests/localresponsenormalization.rs b/wonnx/tests/localresponsenormalization.rs index ddb4a918..c52aff6c 100644 --- a/wonnx/tests/localresponsenormalization.rs +++ b/wonnx/tests/localresponsenormalization.rs @@ -30,7 +30,7 @@ fn local_response_normalization() { vec!["X"], vec!["Y"], "lrn", - "LocalResponseNormalization", + "LRN", vec![ attribute("alpha", 1.0), attribute("beta", 1.0), attribute("bias", 0.0), From 9c3b756188a122de36b824e612d57ce94c492e6c Mon Sep 17 00:00:00 2001 From: Tommy van der Vorst Date: Sun, 26 Mar 2023 23:33:05 +0200 Subject: [PATCH 3/6] chore: get_attribute was renamed --- wonnx/src/compiler.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 108b0f30..3c66f66e 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -1328,16 +1328,16 @@ pub fn compile( } "LRN" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#lrn - let alpha = get_attribute("alpha", Some(0.0001), node)?; - let beta = get_attribute("beta", Some(0.75), node)?; - let bias = get_attribute("bias", Some(1.0), node)?; - let size = get_attribute("size", Some(1), node)?; + let alpha = node.get_attribute_value("alpha", Some(0.0001))?; + let beta = node.get_attribute_value("beta", Some(0.75))?; + let bias = node.get_attribute_value("bias", Some(1.0))?; + let size = node.get_attribute_value("size", Some(1))?; context.insert("alpha", &alpha); context.insert("beta", &beta); context.insert("bias", &bias); context.insert("size", &size); - + let left_size = f64::floor((size - 1) as f64 / 2.0) as u32; let right_size = f64::ceil((size - 1) as f64 / 2.0) as u32; @@ -1348,7 +1348,7 @@ pub fn compile( output_lengths[0], MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, MAX_WORKGROUP_SIZE_X, - )?; + )?; context.insert("workgroup_size_x", &workgroup_size_x); context.insert("i_chunks", &input_chunks); From 992cdcb856aa3c42d3f5fa3bce137efd1ca781b7 Mon Sep 17 00:00:00 2001 From: Tommy van der Vorst Date: Sun, 26 Mar 2023 23:33:05 +0200 Subject: [PATCH 4/6] fix: type error in LRN WGSL --- wonnx/templates/matrix/lrn.wgsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wonnx/templates/matrix/lrn.wgsl b/wonnx/templates/matrix/lrn.wgsl index d18b3518..c3517ab2 100644 --- a/wonnx/templates/matrix/lrn.wgsl +++ b/wonnx/templates/matrix/lrn.wgsl @@ -19,5 +19,5 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { square_sum += I * I; } - output_0.data[c] = input_0.data[ c ] / pow({{bias}} + ({{alpha}} / {{size}}.0) * square_sum,{{beta}}); + output_0.data[c] = input_0.data[ c ] / pow({{ scalar_type }}({{ bias }}) + ({{ scalar_type }}({{ alpha }}) / {{ scalar_type }}({{ size }})) * square_sum, {{ scalar_type }}({{ beta }})); } From de3a2202c3e9c2eab12f5c9d8b154b096541c6c8 Mon Sep 17 00:00:00 2001 From: Tommy van der Vorst Date: Sun, 26 Mar 2023 23:33:05 +0200 Subject: [PATCH 5/6] chore: formatting --- wonnx/tests/localresponsenormalization.rs | 33 ++++++++++++++--------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/wonnx/tests/localresponsenormalization.rs b/wonnx/tests/localresponsenormalization.rs index c52aff6c..8f4202fa 100644 --- a/wonnx/tests/localresponsenormalization.rs +++ b/wonnx/tests/localresponsenormalization.rs @@ -9,7 +9,11 @@ fn local_response_normalization() { let batches = 1; let width_height: usize = 3; let channels: usize = 4; - let data: Vec = [ 1.,1.,2.,4., 2.,2.,1.,2., 3.,1.,2.,1., 4.,2.,3.,5., 3.,3.,2.,2., 6.,2.,3.,1., 7.,3.,4.,2., 8.,4.,3.,2., 9.,3.,4.,4.].to_vec(); + let data: Vec = [ + 1., 1., 2., 4., 2., 2., 1., 2., 3., 1., 2., 1., 4., 2., 3., 5., 3., 3., 2., 2., 6., 2., 3., + 1., 7., 3., 4., 2., 8., 4., 3., 2., 9., 3., 4., 4., + ] + .to_vec(); let shape = vec![ batches as i64, @@ -20,21 +24,22 @@ fn local_response_normalization() { input_data.insert("X".to_string(), data.as_slice().into()); let bn_model = model(graph( - vec![tensor("X", &shape)], // input - vec![tensor("Y", &shape)], // output - vec![], // infos - vec![], // intializers - + vec![tensor("X", &shape)], // input + vec![tensor("Y", &shape)], // output + vec![], // infos + vec![], // intializers // nodes vec![node( vec!["X"], vec!["Y"], "lrn", "LRN", - vec![ attribute("alpha", 1.0), - attribute("beta", 1.0), - attribute("bias", 0.0), - attribute("size", 2)], + vec![ + attribute("alpha", 1.0), + attribute("beta", 1.0), + attribute("bias", 0.0), + attribute("size", 2), + ], )], )); @@ -47,6 +52,10 @@ fn local_response_normalization() { common::assert_eq_vector( out_y.try_into().unwrap(), - &[1.0, 0.4, 0.2, 0.5, 0.5, 0.8, 0.4, 1.0, 0.6, 0.4, 0.8, 2.0, 0.4, 0.30769232, 0.1764706, 0.39999998, 0.33333334, 0.4615385, 0.5, 1.0, 0.3, 0.30769232, 0.6, 2.0, 0.2413793, 0.24, 0.4, 1.0, 0.2, 0.32, 0.4615385, 1.0, 0.2, 0.24, 0.25, 0.5], + &[ + 1.0, 0.4, 0.2, 0.5, 0.5, 0.8, 0.4, 1.0, 0.6, 0.4, 0.8, 2.0, 0.4, 0.30769232, 0.1764706, + 0.39999998, 0.33333334, 0.4615385, 0.5, 1.0, 0.3, 0.30769232, 0.6, 2.0, 0.2413793, + 0.24, 0.4, 1.0, 0.2, 0.32, 0.4615385, 1.0, 0.2, 0.24, 0.25, 0.5, + ], ); -} \ No newline at end of file +} From 11e4876abcd7eaf4d31d070805087015591bb7f2 Mon Sep 17 00:00:00 2001 From: Tommy van der Vorst Date: Sun, 26 Mar 2023 23:33:28 +0200 Subject: [PATCH 6/6] fix(docs): update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3582ea9d..38c8e64c 100644 --- a/README.md +++ b/README.md @@ -272,7 +272,7 @@ fn test_matmul_square_matrix() { |InstanceNormalization|6, 1| |IsInf|10| |IsNaN|13, 9| -|LRN|13, 1|| +|LRN|13, 1|✅|| |LSTM|14, 7, 1| |LeakyRelu|6, 1|✅|✅| |Less|13, 9, 7, 1|✅|