From 3b29650684f62370cbc631e50f278219c62e385e Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Fri, 29 Nov 2024 16:55:08 +0100 Subject: [PATCH] feat: add 'reflection' padding_mode and 'nearest' and 'cubic' mode in xten_nn.grid_sample. --- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 4 ++-- test/Dialect/XTenNN/grid_sample.mlir | 15 +++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index ae5ccd0..e658284 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -608,11 +608,11 @@ static std::string getOpInvalidModeOption(ArrayRef subOptions, LogicalResult amd::xten_nn::GridSampleOp::verify() { - constexpr std::array mode{"bilinear"}; + constexpr std::array mode{"bilinear", "nearest", "cubic"}; if (getMode() > mode.size() - 1) { return emitOpError(getOpInvalidModeOption(mode, getModeAttrName())); } - constexpr std::array paddingMode{"zeros", "border"}; + constexpr std::array paddingMode{"zeros", "border", "reflection"}; if (getPaddingMode() > paddingMode.size() - 1) { return emitOpError( getOpInvalidModeOption(paddingMode, getPaddingModeAttrName())); diff --git a/test/Dialect/XTenNN/grid_sample.mlir b/test/Dialect/XTenNN/grid_sample.mlir index 7a4b60b..4d44292 100644 --- a/test/Dialect/XTenNN/grid_sample.mlir +++ b/test/Dialect/XTenNN/grid_sample.mlir @@ -17,18 +17,17 @@ func.func @test_grid_sample_valid3(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tens return %0 : tensor<*xf32> } -// ----- - -func.func @test_grid_sample_no_padding_mode(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { - // expected-error@+1 {{Valid values for 'padding_mode' option are: 'zeros'(0), 'border'(1)}} +func.func @test_grid_sample_reflection(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { %0 = "xten_nn.grid_sample"(%arg0, %arg1) {align_corners = 1 : i64, mode = 0 : i64, padding_mode = 2 : i64} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// ----- - -func.func @test_grid_sample_interpolate_mode(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { - // expected-error@+1 {{Valid values for 'mode' option are: 'bilinear'(0)}} +func.func @test_grid_sample_interpolate_nearest(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { %0 = "xten_nn.grid_sample"(%arg0, %arg1) {align_corners = 1 : i64, mode = 1 : i64, padding_mode = 1 : i64} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +func.func @test_grid_sample_interpolate_cubic(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { + %0 = "xten_nn.grid_sample"(%arg0, %arg1) {align_corners = 1 : i64, mode = 2 : i64, padding_mode = 1 : i64} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +}