From c7ef459f3344cff8224675d9c9a46651411a9da9 Mon Sep 17 00:00:00 2001 From: Bratislav Filipovic Date: Wed, 2 Oct 2024 13:46:09 +0200 Subject: [PATCH] [Torch] Decompose torch.take op into AtenFlattenUsingInts and AtenSelectIndex Decompose op inside DecomposeComplexOps.cpp Add tests into slice_like.py --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 47 +++++++++++++------ .../Torch/Transforms/DecomposeComplexOps.cpp | 33 +++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + .../build_tools/abstract_interp_lib_gen.py | 10 ++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/slice_like.py | 47 +++++++++++++++++++ 7 files changed, 148 insertions(+), 15 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0b1a8b25720e..bdb98e3f0a81 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13148,6 +13148,30 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [ }]; } +def Torch_AtenTakeOp : Torch_Op<"aten.take", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::take : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$index + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTakeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenTakeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 59cf69393ded..ade9b15da348 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6591,6 +6591,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.take\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" return %arg1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.take\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: indexes must be integer types\"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.linalg_slogdet\"(%arg0: !torch.list) -> !torch.tuple, list> {\n" " %int-2 = torch.constant.int -2\n" " %int-1 = torch.constant.int -1\n" @@ -11238,21 +11270,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %3 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" -" }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index ed0ef9e5b4f0..55462ec7626b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4113,6 +4113,38 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { }; } // namespace +namespace { + +class DecomposeAtenTakeOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTakeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value index = op.getIndex(); + auto selfTy = cast(self.getType()); + auto resType = cast(op.getType()); + int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes + + auto flattenType = selfTy.getWithSizesAndDtype( + /*optionalSizes=*/{selfNumel}, resType.getDtype()); + Value constMinusOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value constZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value flattenSelf = rewriter.create( + loc, flattenType, self, constZero, constMinusOne); + + Value result = rewriter.create( + loc, resType, flattenSelf, constZero, index); + + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + // decompose aten.repeat_interleave.self_int into following ops: // aten.flatten.using_ints, aten.unsqueeze, aten.tile, aten.reshape namespace { @@ -9660,6 +9692,7 @@ class DecomposeComplexOpsPass legalOpsSet.clear(); legalOpsSet.insert(legalOps.begin(), legalOps.end()); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index ebc43faa595c..7947971e9ca0 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](AtenMatmulOp op) { std::optional lhsRank = getTensorRank(op.getSelf()); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bc49757ee9d3..a9f3d887a2dd 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -257,6 +257,16 @@ def aten〇_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List def aten〇_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]: return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1]) +def aten〇take〡shape(self: List[int], index: List[int]) -> List[int]: + return index + +def aten〇take〡dtype(self_rank_dtype: Tuple[int, int], index_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + index_rank, index_dtype = index_rank_dtype + assert is_integer_dtype(index_dtype), "indexes must be integer types" + return self_dtype + + def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]: assert len(A) == 2 or len(A) == 3 assert A[-1] == A[-2] diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5f53e17b9d17..b1fbac096525 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -964,6 +964,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)" ) + emit("aten::take : (Tensor, Tensor) -> (Tensor)") # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index deaf2fd6cac3..48579e81021b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -1121,3 +1121,50 @@ def forward(self, x): @register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule()) def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 5)) + + +# ============================================================================== + + +class TakeModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.float32, True], [(4,), torch.int64, True]]) + def forward(self, input, index): + return torch.take(input, index) + + +@register_test_case(module_factory=lambda: TakeModule()) +def TakeModule_F32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.float32) + index = tu.rand(4, low=0, high=torch.numel(A)).to(dtype=torch.int64) + module.forward(A, index) + + +class TakeBatchModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4, 4), torch.float32, True], [(4,), torch.int64, True]]) + def forward(self, input, index): + return torch.take(input, index) + + +@register_test_case(module_factory=lambda: TakeBatchModule()) +def TakeModuleBatched_F32(module, tu: TestUtils): + A = tu.rand(4, 4, 4).to(dtype=torch.float32) + index = tu.rand(4, low=0, high=torch.numel(A)).to(dtype=torch.int64) + module.forward(A, index) + + +class TakeDynamicModule(torch.nn.Module): + @export + @annotate_args( + [None, [(-1, -1, -1), torch.float32, True], [(4,), torch.int64, True]] + ) + def forward(self, input, index): + return torch.take(input, index) + + +@register_test_case(module_factory=lambda: TakeDynamicModule()) +def TakeModuleDynamic_F32(module, tu: TestUtils): + A = tu.rand(4, 4, 8).to(dtype=torch.float32) + index = tu.rand(4, low=0, high=torch.numel(A)).to(dtype=torch.int64) + module.forward(A, index)