Commit ee76171

Fix conversion for the resulting type of the torch operation (#54)

josel-amd authored Jun 19, 2024
1 parent 5a3e509 commit ee76171
Showing 2 changed files with 45 additions and 15 deletions.
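In effect, result types of the emitted torch operations now receive the same element-type normalization as the operands: signless integer element types are converted to signed torch integer types, while i1 is kept signless. An illustrative summary of the mapping, mirroring the test cases in this commit:

// Illustrative type mapping performed by toTorchTensorTypeCast:
//   tensor<1x3x4x5xi32>       ->  !torch.vtensor<[1,3,4,5],si32>      (signless i32 becomes si32)
//   tensor<1x3x4x5xi1>        ->  !torch.vtensor<[1,3,4,5],i1>        (i1 stays signless)
//   tensor<1x32x122x122xf32>  ->  !torch.vtensor<[1,32,122,122],f32>  (floats are unchanged)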
lib/Conversion/XTenNNToTorch.cpp: 15 additions & 14 deletions
@@ -37,20 +37,25 @@ using namespace mlir::torch;

 namespace {
 
-Value toTorchTensorTypeCast(PatternRewriter &rewriter, Value input) {
-  auto tensorTy = dyn_cast<ShapedType>(input.getType());
-  auto sizes = tensorTy.getShape();
-  auto tensorEltTy = tensorTy.getElementType();
-  if (tensorEltTy.isSignlessInteger()) {
-    tensorEltTy = rewriter.getIntegerType(tensorTy.getElementTypeBitWidth(), true);
+Type toTorchTensorTypeCast(PatternRewriter &rewriter, ShapedType ty) {
+  auto elementType = ty.getElementType();
+
+  auto intElementType = dyn_cast<IntegerType>(ty.getElementType());
+  if (intElementType && intElementType.isSignlessInteger() && intElementType.getWidth() != 1) {
+    elementType = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth(),
+                                          /*isSigned=*/true);
   }
 
+  return Torch::ValueTensorType::get(ty.getContext(), ty.getShape(),
+                                     elementType);
+}
+
+Value toTorchTensorTypeCast(PatternRewriter &rewriter, Value input) {
+  auto tensorTy = cast<ShapedType>(input.getType());
   return rewriter
       .create<TorchConversion::FromBuiltinTensorOp>(
-          input.getLoc(),
-          mlir::torch::Torch::ValueTensorType::get(input.getContext(), sizes,
-                                                   tensorEltTy),
+          input.getLoc(), toTorchTensorTypeCast(rewriter, tensorTy),
           input)
       .getResult();
 }
@@ -296,8 +301,6 @@ class ApplyXTenNNToTorch : public OpConversionPattern<SrcOpT> {
   matchAndRewrite(SrcOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto *ctx = op->getContext();
-
     SmallVector<Value> vtensorOperands;
     llvm::transform(
         op->getOperands(), std::back_inserter(vtensorOperands),
@@ -307,9 +310,7 @@ class ApplyXTenNNToTorch : public OpConversionPattern<SrcOpT> {
     SmallVector<Type> vtensorResultTypes;
     llvm::transform(op->getResultTypes(),
                     std::back_inserter(vtensorResultTypes), [&](Type ty) {
-                      auto tensorTy = cast<TensorType>(ty);
-                      return Torch::ValueTensorType::get(
-                          ctx, tensorTy.getShape(), tensorTy.getElementType());
+                      return toTorchTensorTypeCast(rewriter, cast<ShapedType>(ty));
                     });
 
     // Call the function that creates the new operation.
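Previously, operand values were converted through toTorchTensorTypeCast (signless integers mapped to signed), while result types were rebuilt directly from the builtin element type, so the emitted torch operation could carry si32 operands but an i32 result. A sketch of the effect, using the i32 reflect_pad case tested below:

// Before the fix (sketch): operand and result signedness disagree.
//   torch.operator "onnx.Pad"(...) : (!torch.vtensor<[1,3,4,5],si32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],i32>
// After the fix: both sides go through toTorchTensorTypeCast.
//   torch.operator "onnx.Pad"(...) : (!torch.vtensor<[1,3,4,5],si32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],si32>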
test/Conversion/XTenNNToTorch/reflect_pad.mlir: 30 additions & 1 deletion
@@ -17,6 +17,7 @@ func.func @reflect_pad_bf16(%arg0: tensor<1x32x122x122xbf16>) -> tensor<1x32x124
 // CHECK: return %[[TO]] : tensor<1x32x124x124xbf16>
 
 // -----
+
 func.func @reflect_pad_f32(%arg0: tensor<1x32x122x122xf32>) -> tensor<1x32x124x124xf32> {
   %pad = "tosa.const"() <{value = dense<[0, 0, 1, 1, 0, 0, 1, 1]> : tensor<8xi64>}> : () -> tensor<8xi64>
   %reflect_pad = xten_nn.reflect_pad %arg0, %pad {LayerName = "Pad_282", OutputName = "Pad_282"} : (tensor<1x32x122x122xf32>, tensor<8xi64>) -> (tensor<1x32x124x124xf32>)
@@ -29,4 +30,32 @@ func.func @reflect_pad_f32(%arg0: tensor<1x32x122x122xf32>) -> tensor<1x32x124x1
 // CHECK: %[[FROM_PADS:.*]] = torch_c.from_builtin_tensor %[[PADS]] : tensor<8xi64> -> !torch.vtensor<[8],si64>
 // CHECK: %[[OP:.*]] = torch.operator "onnx.Pad"(%[[FROM_ARG]], %[[FROM_PADS]]) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,32,122,122],f32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,32,124,124],f32>
 // CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[OP]] : !torch.vtensor<[1,32,124,124],f32> -> tensor<1x32x124x124xf32>
-// CHECK: return %[[TO]] : tensor<1x32x124x124xf32>
+// CHECK: return %[[TO]] : tensor<1x32x124x124xf32>
+
+// -----
+
+func.func @reflect_pad_i32(%arg0: tensor<1x3x4x5xi32> , %arg1: tensor<8xi64>) -> tensor<1x3x6x7xi32> {
+  %0 = xten_nn.reflect_pad %arg0, %arg1 : (tensor<1x3x4x5xi32>, tensor<8xi64>) -> tensor<1x3x6x7xi32>
+  return %0 : tensor<1x3x6x7xi32>
+}
+
+// CHECK-LABEL: func @reflect_pad_i32
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<1x3x4x5xi32>, %[[ARG_1:.*]]: tensor<8xi64>) -> tensor<1x3x6x7xi32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[FROM_ARG_0:.*]] = torch_c.from_builtin_tensor %[[ARG_0]] : tensor<1x3x4x5xi32> -> !torch.vtensor<[1,3,4,5],si32>
+// CHECK: %[[FROM_ARG_1:.*]] = torch_c.from_builtin_tensor %[[ARG_1]] : tensor<8xi64> -> !torch.vtensor<[8],si64>
+// CHECK: %[[OP:.*]] = torch.operator "onnx.Pad"(%[[FROM_ARG_0]], %[[FROM_ARG_1]]) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,3,4,5],si32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],si32>
+// CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[OP]] : !torch.vtensor<[1,3,6,7],si32> -> tensor<1x3x6x7xi32>
+// CHECK: return %[[TO]] : tensor<1x3x6x7xi32>
+
+func.func @reflect_pad_i1(%arg0: tensor<1x3x4x5xi1> , %arg1: tensor<8xi64>) -> tensor<1x3x6x7xi1> {
+  %0 = xten_nn.reflect_pad %arg0, %arg1 : (tensor<1x3x4x5xi1>, tensor<8xi64>) -> tensor<1x3x6x7xi1>
+  return %0 : tensor<1x3x6x7xi1>
+}
+
+// CHECK-LABEL: func @reflect_pad_i1
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<1x3x4x5xi1>, %[[ARG_1:.*]]: tensor<8xi64>) -> tensor<1x3x6x7xi1> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+// CHECK: %[[FROM_ARG_0:.*]] = torch_c.from_builtin_tensor %[[ARG_0]] : tensor<1x3x4x5xi1> -> !torch.vtensor<[1,3,4,5],i1>
+// CHECK: %[[FROM_ARG_1:.*]] = torch_c.from_builtin_tensor %[[ARG_1]] : tensor<8xi64> -> !torch.vtensor<[8],si64>
+// CHECK: %[[OP:.*]] = torch.operator "onnx.Pad"(%[[FROM_ARG_0]], %[[FROM_ARG_1]]) {torch.onnx.mode = "reflect"} : (!torch.vtensor<[1,3,4,5],i1>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[1,3,6,7],i1>
+// CHECK: %[[TO:.*]] = torch_c.to_builtin_tensor %[[OP]] : !torch.vtensor<[1,3,6,7],i1> -> tensor<1x3x6x7xi1>
+// CHECK: return %[[TO]] : tensor<1x3x6x7xi1>
