From 47eea89a70c38a2966cf27a2cc466b1f83def1fa Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Wed, 10 Jan 2024 23:58:45 -0800 Subject: [PATCH 1/5] init --- test/stablehlo/test_mark_pattern.py | 5 +- .../runtime/stablehlo_composite_helper.cc | 98 ++++++++++++++----- 2 files changed, 77 insertions(+), 26 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index e0fcce201f7..cb1ecc5ae19 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -190,7 +190,7 @@ def forward(self, x, y): '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo) self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) - def test_multiple_input(self): + def test_multiple_inputs(self): def f(x, y): x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) @@ -205,8 +205,7 @@ def f(x, y): self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) self.assertTrue('{attributes = {}, name = "p"}' in stablehlo) - @unittest.skip("Multiple outputs patterns are not supported now.") - def test_multiple_output(self): + def test_multiple_outputs(self): def f(x, y): x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc index bc38fc65997..34830aa15f1 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc @@ -44,7 +44,7 @@ struct BoundaryMetadata { bool is_input; std::unordered_map attrs; - auto boundary_key() const { return std::forward_as_tuple(name, id); } + auto boundary_key() const { return absl::StrCat(name, "__@@__", id); } auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); } @@ -116,10 +116,12 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { for (mlir::func::FuncOp& func_op : func_ops) { llvm::DenseMap op_order_map = BuildOpOrderMap(func_op); - for (auto op : func_op.getOps()) { - if (mlir::failed( - BuildStableHLOComposite(op.getOperation(), op_order_map))) { - op.emitError() << "failed to build composite."; + std::unordered_map> + boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); + + for (const auto& [unused, ops] : boundary_output_ops_map) { + if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + func_op.emitError() << "failed to build composite."; return signalPassFailure(); } } @@ -144,6 +146,31 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { return op_order_map; } + std::unordered_map> + BuildBoundaryOutputOpsMap(mlir::func::FuncOp func_op) { + std::unordered_map> + boundary_output_ops; + + for (auto op : func_op.getOps()) { + auto metadata_or = GetBoundaryMetadata(op); + if (mlir::failed(metadata_or)) { + continue; + } + + std::unique_ptr metadata = std::move(*metadata_or); + if (metadata == nullptr || metadata->is_input) { + continue; + } + + auto& output_ops = boundary_output_ops[metadata->boundary_key()]; + if (metadata->pos >= output_ops.size()) { + output_ops.resize(metadata->pos + 1, nullptr); + } + output_ops[metadata->pos] = op.getOperation(); + } + return boundary_output_ops; + } + mlir::FailureOr> GetBoundaryMetadata( mlir::Operation* op) { if (!IsXlaMarkTensorOp(op)) { @@ -196,9 +223,21 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } mlir::LogicalResult BuildStableHLOComposite( - mlir::Operation* op, + const llvm::SmallVector& output_ops, const llvm::DenseMap& op_order_map) { - auto metadata_or = GetBoundaryMetadata(op); + if (output_ops.empty()) { + return mlir::success(); + } + + // Get the output op with maximum order num as the representative. + mlir::Operation* last_output_op = output_ops[0]; + for (mlir::Operation* op : output_ops) { + if (op_order_map.at(op) > op_order_map.at(last_output_op)) { + last_output_op = op; + } + } + + auto metadata_or = GetBoundaryMetadata(last_output_op); if (mlir::failed(metadata_or)) { return mlir::failure(); } @@ -208,7 +247,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { return mlir::success(); } - auto args_ops_or = GetBoundaryArgsAndOps(op, *metadata, op_order_map); + auto args_ops_or = + GetBoundaryArgsAndOps(output_ops, *metadata, op_order_map); if (mlir::failed(args_ops_or)) { return mlir::failure(); } @@ -216,10 +256,10 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { auto [args, impl_ops] = *args_ops_or; mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc( - op, absl::StrCat(metadata->name, ".impl"), args, impl_ops); + output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops); mlir::FailureOr composite_op_or = - BuildStableHLOCompositeOp(op, impl_func, args, *metadata); + BuildStableHLOCompositeOp(last_output_op, impl_func, args, *metadata); if (mlir::failed(composite_op_or)) { return mlir::failure(); } @@ -227,9 +267,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { // Updates all users of this op's result(s) to use the results(s) of impl // func call. - for (size_t i = 0; i < op->getNumResults(); ++i) { - mlir::OpResult result = op->getResult(i); - result.replaceAllUsesWith(composite_op->getResult(i)); + size_t next_composite_result_i = 0; + for (mlir::Operation* op : output_ops) { + for (size_t i = 0; i < op->getNumResults(); ++i) { + mlir::OpResult result = op->getResult(i); + result.replaceAllUsesWith( + composite_op->getResult(next_composite_result_i++)); + } } // The unused impl_ops will be eliminated with canonicalizer. @@ -239,11 +283,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { mlir::FailureOr, llvm::SmallVector>> GetBoundaryArgsAndOps( - mlir::Operation* boundary_output_op, const BoundaryMetadata& metadata, + const llvm::SmallVector boundary_output_ops, + const BoundaryMetadata& metadata, const llvm::DenseMap& op_order_map) { llvm::SetVector impl_ops_setvec; llvm::SetVector> arg_pos_setvec; - llvm::SmallVector processing({boundary_output_op}); + llvm::SmallVector processing(boundary_output_ops.begin(), + boundary_output_ops.end()); // Reverse graph traversal: from boundary output op to boundary input op, // global function arg, or stablehlo constant. @@ -318,8 +364,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } mlir::func::FuncOp BuildStableHLOCompositeImplFunc( - mlir::Operation* boundary_output_op, llvm::StringRef func_name, - const llvm::SmallVector& args, + const llvm::SmallVector boundary_output_ops, + llvm::StringRef func_name, const llvm::SmallVector& args, const llvm::SmallVector& impl_ops) { mlir::ModuleOp module_op = getOperation(); mlir::MLIRContext* context = &getContext(); @@ -333,9 +379,11 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { arg_types.push_back(arg.getType()); arg_locs.push_back(arg.getLoc()); } - llvm::SmallVector result_types( - boundary_output_op->getResultTypes().begin(), - boundary_output_op->getResultTypes().end()); + llvm::SmallVector result_types; + for (mlir::Operation* op : boundary_output_ops) { + result_types.append(op->getResultTypes().begin(), + op->getResultTypes().end()); + } mlir::func::FuncOp impl_func = builder.create( module_op.getLoc(), func_name, @@ -350,9 +398,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { mlir::Operation* cloned_op = builder.clone(*original_op, mapping); mapping.map(original_op, cloned_op); } - builder.create( - impl_func.getBody().getLoc(), - mapping.lookup(boundary_output_op)->getResults()); + + llvm::SmallVector results; + for (mlir::Operation* op : boundary_output_ops) { + results.append(mapping.lookup(op)->getResults().begin(), + mapping.lookup(op)->getResults().end()); + } + builder.create(impl_func.getBody().getLoc(), results); // Adds the new function to symbol table. mlir::SymbolTable symbol_table(module_op); From 5e288beb0d543c894bb5bc92923b153789fadf1d Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Thu, 11 Jan 2024 00:13:08 -0800 Subject: [PATCH 2/5] update --- test/stablehlo/test_mark_pattern.py | 22 ++++++++++++++++++++ torch_xla/experimental/mark_pattern_utils.py | 4 ---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index cb1ecc5ae19..ae5dc3b264a 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -190,6 +190,28 @@ def forward(self, x, y): '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo) self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) + def test_composite_builder_multiple_outputs(self): + + class M(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, y): + builder = StableHLOCompositeBuilder("sample_composite") + x, y = builder.mark_inputs(x, y) + a = x + y + b = x - y + c = x + 1 + a, b, c = builder.mark_outputs(a, b, c) + return a + b + c + + input_args = (torch.randn((5, 5)), torch.randn((5, 5))) + stablehlo = self.run_func_get_stablehlo(M(), input_args) + print(stablehlo) + self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) + self.assertTrue(False) + def test_multiple_inputs(self): def f(x, y): diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py index 4665f9b14e0..317107c2e55 100644 --- a/torch_xla/experimental/mark_pattern_utils.py +++ b/torch_xla/experimental/mark_pattern_utils.py @@ -80,8 +80,4 @@ def mark_outputs(self, *tensors: torch.Tensor): should be replaced by the marked tensors in later usages. """ - if len(tensors) > 1: - # TODO: Allow multiple composite outputs - raise ValueError( - f"StableHLO composite with more than one outputs is not supported.") return self._mark_tensor(*tensors, is_input=False) From 7d5376c59ad03d75184578af799d1e98517fc17e Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Wed, 31 Jan 2024 22:07:29 -0800 Subject: [PATCH 3/5] fix --- test/stablehlo/test_mark_pattern.py | 2 -- .../csrc/runtime/stablehlo_composite_helper.cc | 17 ++++++++--------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index ae5dc3b264a..b2a7284cf6a 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -208,9 +208,7 @@ def forward(self, x, y): input_args = (torch.randn((5, 5)), torch.randn((5, 5))) stablehlo = self.run_func_get_stablehlo(M(), input_args) - print(stablehlo) self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) - self.assertTrue(False) def test_multiple_inputs(self): diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc index 34830aa15f1..413364513f3 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc @@ -229,15 +229,15 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { return mlir::success(); } - // Get the output op with maximum order num as the representative. - mlir::Operation* last_output_op = output_ops[0]; + // Get the output op with minimum order num as the representative. + mlir::Operation* first_output_op = output_ops[0]; for (mlir::Operation* op : output_ops) { - if (op_order_map.at(op) > op_order_map.at(last_output_op)) { - last_output_op = op; + if (op_order_map.at(op) < op_order_map.at(first_output_op)) { + first_output_op = op; } } - auto metadata_or = GetBoundaryMetadata(last_output_op); + auto metadata_or = GetBoundaryMetadata(first_output_op); if (mlir::failed(metadata_or)) { return mlir::failure(); } @@ -257,9 +257,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc( output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops); - mlir::FailureOr composite_op_or = - BuildStableHLOCompositeOp(last_output_op, impl_func, args, *metadata); + BuildStableHLOCompositeOp(first_output_op, impl_func, args, *metadata); if (mlir::failed(composite_op_or)) { return mlir::failure(); } @@ -267,12 +266,12 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { // Updates all users of this op's result(s) to use the results(s) of impl // func call. - size_t next_composite_result_i = 0; + size_t composite_result_i = 0; for (mlir::Operation* op : output_ops) { for (size_t i = 0; i < op->getNumResults(); ++i) { mlir::OpResult result = op->getResult(i); result.replaceAllUsesWith( - composite_op->getResult(next_composite_result_i++)); + composite_op->getResult(composite_result_i++)); } } From aeec3edae9c7e8dfefe6a61c7a59b38bb311a7ca Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Wed, 31 Jan 2024 22:18:53 -0800 Subject: [PATCH 4/5] Update test_mark_pattern.py --- test/stablehlo/test_mark_pattern.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index b2a7284cf6a..05b73189e0e 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -54,9 +54,6 @@ def test_sdpa_pattern(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): q, k, v = x.split(128, dim=-2) q = torch.ops.xla_pattern_marking.mark_tensor( @@ -98,9 +95,6 @@ def test_composite_builder_sdpa_pattern(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25}) q, k, v = x.split(128, dim=-2) @@ -161,9 +155,6 @@ def test_inlined_composite_builder_export_sdpa_pattern(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25}) q, k, v = x.split(128, dim=-2) @@ -194,9 +185,6 @@ def test_composite_builder_multiple_outputs(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): builder = StableHLOCompositeBuilder("sample_composite") x, y = builder.mark_inputs(x, y) From a6f200164214ecdac199ec6cd5f7144a9fd522c4 Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Sun, 4 Feb 2024 07:50:09 -0800 Subject: [PATCH 5/5] Update stablehlo_composite_helper.cc --- torch_xla/csrc/runtime/stablehlo_composite_helper.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc index 413364513f3..0f0492dac0d 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc @@ -244,7 +244,9 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { std::unique_ptr metadata = std::move(*metadata_or); if (metadata == nullptr || metadata->is_input) { - return mlir::success(); + // There should always be a valid boundary output metadata associated with + // each op in output_ops. + return mlir::failure(); } auto args_ops_or =