Skip to content

Commit

Permalink
Merge pull request #115 from Xilinx/tiagot.support_more_configuration…
Browse files Browse the repository at this point in the history
…s_gridsample

feat: add 'reflection' padding_mode and 'nearest' and 'cubic' mode in xten_nn.grid_sample.
  • Loading branch information
ttjost authored Nov 29, 2024
2 parents 8eb043f + 3b29650 commit eb1cc2b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,11 +608,11 @@ static std::string getOpInvalidModeOption(ArrayRef<const char *> subOptions,

LogicalResult amd::xten_nn::GridSampleOp::verify() {

constexpr std::array mode{"bilinear"};
constexpr std::array mode{"bilinear", "nearest", "cubic"};
if (getMode() > mode.size() - 1) {
return emitOpError(getOpInvalidModeOption(mode, getModeAttrName()));
}
constexpr std::array paddingMode{"zeros", "border"};
constexpr std::array paddingMode{"zeros", "border", "reflection"};
if (getPaddingMode() > paddingMode.size() - 1) {
return emitOpError(
getOpInvalidModeOption(paddingMode, getPaddingModeAttrName()));
Expand Down
15 changes: 7 additions & 8 deletions test/Dialect/XTenNN/grid_sample.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@ func.func @test_grid_sample_valid3(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tens
return %0 : tensor<*xf32>
}

// -----

func.func @test_grid_sample_no_padding_mode(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
// expected-error@+1 {{Valid values for 'padding_mode' option are: 'zeros'(0), 'border'(1)}}
func.func @test_grid_sample_reflection(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
%0 = "xten_nn.grid_sample"(%arg0, %arg1) {align_corners = 1 : i64, mode = 0 : i64, padding_mode = 2 : i64} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

func.func @test_grid_sample_interpolate_mode(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
// expected-error@+1 {{Valid values for 'mode' option are: 'bilinear'(0)}}
func.func @test_grid_sample_interpolate_nearest(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
%0 = "xten_nn.grid_sample"(%arg0, %arg1) {align_corners = 1 : i64, mode = 1 : i64, padding_mode = 1 : i64} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func.func @test_grid_sample_interpolate_cubic(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> {
%0 = "xten_nn.grid_sample"(%arg0, %arg1) {align_corners = 1 : i64, mode = 2 : i64, padding_mode = 1 : i64} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

0 comments on commit eb1cc2b

Please sign in to comment.