From ee7617184df68bcda0b3bb1d9baf4ff30afc877f Mon Sep 17 00:00:00 2001 From: josel-amd <166385423+josel-amd@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:42:34 +0200 Subject: [PATCH] Fix conversion for the resulting type of the torch operation (#54) Fix conversion for the resulting type of the torch operation --- lib/Conversion/XTenNNToTorch.cpp | 29 ++++++++--------- .../Conversion/XTenNNToTorch/reflect_pad.mlir | 31 ++++++++++++++++++- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/XTenNNToTorch.cpp b/lib/Conversion/XTenNNToTorch.cpp index eab4f97e..cb7dbed5 100644 --- a/lib/Conversion/XTenNNToTorch.cpp +++ b/lib/Conversion/XTenNNToTorch.cpp @@ -37,20 +37,25 @@ using namespace mlir::torch; namespace { -Value toTorchTensorTypeCast(PatternRewriter &rewriter, Value input) { +Type toTorchTensorTypeCast(PatternRewriter &rewriter, ShapedType ty) { + auto elementType = ty.getElementType(); - auto tensorTy = dyn_cast(input.getType()); - auto sizes = tensorTy.getShape(); - auto tensorEltTy = tensorTy.getElementType(); - if (tensorEltTy.isSignlessInteger()) { - tensorEltTy = rewriter.getIntegerType(tensorTy.getElementTypeBitWidth(), true); + auto intElementType = dyn_cast(ty.getElementType()); + if (intElementType && intElementType.isSignlessInteger() && intElementType.getWidth() != 1) { + elementType = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth(), + /*isSigned=*/true); } + return Torch::ValueTensorType::get(ty.getContext(), ty.getShape(), + elementType); +} + +Value toTorchTensorTypeCast(PatternRewriter &rewriter, Value input) { + auto tensorTy = cast(input.getType()); return rewriter .create( - input.getLoc(), - mlir::torch::Torch::ValueTensorType::get(input.getContext(), sizes, - tensorEltTy), + input.getLoc(), toTorchTensorTypeCast(rewriter, tensorTy), + input) .getResult(); } @@ -296,8 +301,6 @@ class ApplyXTenNNToTorch : public OpConversionPattern { matchAndRewrite(SrcOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto *ctx = op->getContext(); - SmallVector vtensorOperands; llvm::transform( op->getOperands(), std::back_inserter(vtensorOperands), @@ -307,9 +310,7 @@ class ApplyXTenNNToTorch : public OpConversionPattern { SmallVector vtensorResultTypes; llvm::transform(op->getResultTypes(), std::back_inserter(vtensorResultTypes), [&](Type ty) { - auto tensorTy = cast(ty); - return Torch::ValueTensorType::get( - ctx, tensorTy.getShape(), tensorTy.getElementType()); + return toTorchTensorTypeCast(rewriter, cast(ty)); }); // Call the function that creates the new operation. diff --git a/test/Conversion/XTenNNToTorch/reflect_pad.mlir b/test/Conversion/XTenNNToTorch/reflect_pad.mlir index f811ff1e..9a903657 100644 --- a/test/Conversion/XTenNNToTorch/reflect_pad.mlir +++ b/test/Conversion/XTenNNToTorch/reflect_pad.mlir @@ -17,6 +17,7 @@ func.func @reflect_pad_bf16(%arg0: tensor<1x32x122x122xbf16>) -> tensor<1x32x124 // CHECK: return %[[TO]] : tensor<1x32x124x124xbf16> // ----- + func.func @reflect_pad_f32(%arg0: tensor<1x32x122x122xf32>) -> tensor<1x32x124x124xf32> { %pad = "tosa.const"() <{value = dense<[0, 0, 1, 1, 0, 0, 1, 1]> : tensor<8xi64>}> : () -> tensor<8xi64> %reflect_pad = xten_nn.reflect_pad %arg0, %pad {LayerName = "Pad_282", OutputName = "Pad_282"} : (tensor<1x32x122x122xf32>, tensor<8xi64>) -> (tensor<1x32x124x124xf32>) @@ -29,4 +30,32 @@ func.func @reflect_pad_f32(%arg0: tensor<1x32x122x122xf32>) -> tensor<1x32x124x1 // CHECK: %[[FROM_PADS:.*]] = torch_c.from_builtin_tensor %[[PADS]] : tensor<8xi64> -> !torch.vtensor<[8],si64> // CHECK: %[[OP:.*]] = torch.operator "onnx.Pad"(%[[FROM_ARG]], %[[FROM_PADS]]) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,32,122,122],f32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,32,124,124],f32> // CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[OP]] : !torch.vtensor<[1,32,124,124],f32> -> tensor<1x32x124x124xf32> -// CHECK: return %[[TO]] : tensor<1x32x124x124xf32> \ No newline at end of file +// CHECK: return %[[TO]] : tensor<1x32x124x124xf32> + +// ----- + +func.func @reflect_pad_i32(%arg0: tensor<1x3x4x5xi32> , %arg1: tensor<8xi64>) -> tensor<1x3x6x7xi32> { + %0 = xten_nn.reflect_pad %arg0, %arg1 : (tensor<1x3x4x5xi32>, tensor<8xi64>) -> tensor<1x3x6x7xi32> + return %0 : tensor<1x3x6x7xi32> +} + +// CHECK-LABEL: func @reflect_pad_i32 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<1x3x4x5xi32>, %[[ARG_1:.*]]: tensor<8xi64>) -> tensor<1x3x6x7xi32> attributes {torch.onnx_meta.opset_version = 19 : si64} { +// CHECK: %[[FROM_ARG_0:.*]] = torch_c.from_builtin_tensor %[[ARG_0]] : tensor<1x3x4x5xi32> -> !torch.vtensor<[1,3,4,5],si32> +// CHECK: %[[FROM_ARG_1:.*]] = torch_c.from_builtin_tensor %[[ARG_1]] : tensor<8xi64> -> !torch.vtensor<[8],si64> +// CHECK: %[[OP:.*]] = torch.operator "onnx.Pad"(%[[FROM_ARG_0]], %[[FROM_ARG_1]]) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,3,4,5],si32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],si32> +// CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[OP]] : !torch.vtensor<[1,3,6,7],si32> -> tensor<1x3x6x7xi32> +// CHECK: return %[[TO]] : tensor<1x3x6x7xi32> + +func.func @reflect_pad_i1(%arg0: tensor<1x3x4x5xi1> , %arg1: tensor<8xi64>) -> tensor<1x3x6x7xi1> { + %0 = xten_nn.reflect_pad %arg0, %arg1 : (tensor<1x3x4x5xi1>, tensor<8xi64>) -> tensor<1x3x6x7xi1> + return %0 : tensor<1x3x6x7xi1> +} + +// CHECK-LABEL: func @reflect_pad_i1 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<1x3x4x5xi1>, %[[ARG_1:.*]]: tensor<8xi64>) -> tensor<1x3x6x7xi1> attributes {torch.onnx_meta.opset_version = 19 : si64} { +// CHECK: %[[FROM_ARG_0:.*]] = torch_c.from_builtin_tensor %[[ARG_0]] : tensor<1x3x4x5xi1> -> !torch.vtensor<[1,3,4,5],i1> +// CHECK: %[[FROM_ARG_1:.*]] = torch_c.from_builtin_tensor %[[ARG_1]] : tensor<8xi64> -> !torch.vtensor<[8],si64> +// CHECK: %[[OP:.*]] = torch.operator "onnx.Pad"(%[[FROM_ARG_0]], %[[FROM_ARG_1]]) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,3,4,5],i1>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],i1> +// CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[OP]] : !torch.vtensor<[1,3,6,7],i1> -> tensor<1x3x6x7xi1> +// CHECK: return %[[TO]] : tensor<1x3x6x7xi1>