diff --git a/wonnx/tests/matrix.rs b/wonnx/tests/matrix.rs index e138e611..e0cd2f5d 100644 --- a/wonnx/tests/matrix.rs +++ b/wonnx/tests/matrix.rs @@ -374,6 +374,42 @@ fn test_pad_reflect() { assert_eq!(actual, &test_y); } +#[test] +fn test_pad_reflect_complex() { + let mut input_data = HashMap::new(); + #[rustfmt::skip] + let data = [ + 1.0, 1.2, 1.3, + 2.3, 3.4, 4.5, + 4.5, 5.7, 6.8, + ].to_vec(); + input_data.insert("X".to_string(), data.as_slice().into()); + + let model = model(graph( + vec![tensor("X", &[3, 3])], + vec![tensor("Y", &[3, 7])], + vec![], + vec![initializer_int64("pads", vec![0, 2, 0, 2], vec![4])], + vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![ + attribute("mode", "reflect"), + ])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + #[rustfmt::skip] + let test_y = vec![ + 1.3, 1.2, 1.0, 1.2, 1.3, 1.2, 1.0, + 4.5, 3.4, 2.3, 3.4, 4.5, 3.4, 2.3, + 6.8, 5.7, 4.5, 5.7, 6.8, 5.7, 4.5, + ]; + let actual: &[_] = (&result["Y"]).try_into().unwrap(); + // No arithmetic is done, so `assert_eq!` can be used. + assert_eq!(actual, &test_y); +} + #[test] fn test_resize() { let _ = env_logger::builder().is_test(true).try_init();