Skip to content

Commit

Permalink
add compiler support and test for pad reflect
Browse files Browse the repository at this point in the history
  • Loading branch information
redthing1 committed Jul 30, 2023
1 parent 78958f2 commit 19737cb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
2 changes: 2 additions & 0 deletions wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1308,13 +1308,15 @@ pub fn compile(
let mode = node.get_attribute_value("mode", Some("constant".to_string()))?;
match mode.as_str() {
"constant" => {}
"reflect" => {}
_ => {
return Err(CompileError::UnimplementedVariant {
op: String::from("Pad"),
variant: format!("mode={}", mode),
})
}
}
context.insert("mode", &mode);

let pads: Vec<i64> = node.get_attribute_value("pads", None)?;
if pads.len() != input_shapes[0].rank() * 2 {
Expand Down
36 changes: 36 additions & 0 deletions wonnx/tests/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,42 @@ fn test_pad_complex() {
assert_eq!(actual, &test_y);
}

#[test]
fn test_pad_reflect() {
let mut input_data = HashMap::new();
#[rustfmt::skip]
let data = [
1.0, 1.2,
2.3, 3.4,
4.5, 5.7,
].to_vec();
input_data.insert("X".to_string(), data.as_slice().into());

let model = model(graph(
vec![tensor("X", &[3, 2])],
vec![tensor("Y", &[3, 4])],
vec![],
vec![initializer_int64("pads", vec![0, 2, 0, 0], vec![4])],
vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![
attribute("mode", "reflect"),
])],
));

let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create");
let result = pollster::block_on(session.run(&input_data)).unwrap();

#[rustfmt::skip]
let test_y = vec![
1.0, 1.2, 1.0, 1.2,
2.3, 3.4, 2.3, 3.4,
4.5, 5.7, 4.5, 5.7,
];
let actual: &[_] = (&result["Y"]).try_into().unwrap();
// No arithmetic is done, so `assert_eq!` can be used.
assert_eq!(actual, &test_y);
}

#[test]
fn test_resize() {
let _ = env_logger::builder().is_test(true).try_init();
Expand Down

0 comments on commit 19737cb

Please sign in to comment.