diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index 0e635da1e4a..c084357fa52 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -54,9 +54,6 @@ def test_sdpa_pattern(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): q, k, v = x.split(128, dim=-2) q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="0", is_input=True) @@ -92,9 +89,6 @@ def test_composite_builder_sdpa_pattern(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25}) q, k, v = x.split(128, dim=-2) @@ -155,9 +149,6 @@ def test_inlined_composite_builder_export_sdpa_pattern(self): class M(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25}) q, k, v = x.split(128, dim=-2) @@ -184,7 +175,24 @@ def forward(self, x, y): '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo) self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb'))) - def test_multiple_input(self): + def test_composite_builder_multiple_outputs(self): + + class M(torch.nn.Module): + + def forward(self, x, y): + builder = StableHLOCompositeBuilder("sample_composite") + x, y = builder.mark_inputs(x, y) + a = x + y + b = x - y + c = x + 1 + a, b, c = builder.mark_outputs(a, b, c) + return a + b + c + + input_args = (torch.randn((5, 5)), torch.randn((5, 5))) + stablehlo = self.run_func_get_stablehlo(M(), input_args) + self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) + + def test_multiple_inputs(self): def f(x, y): x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) @@ -199,8 +207,7 @@ def f(x, y): self.assertEqual(stablehlo.count("@stablehlo.composite"), 1) self.assertTrue('{attributes = {}, name = "p"}' in stablehlo) - @unittest.skip("Multiple outputs patterns are not supported now.") - def test_multiple_output(self): + def test_multiple_outputs(self): def f(x, y): x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True) diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc index bc38fc65997..0f0492dac0d 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc @@ -44,7 +44,7 @@ struct BoundaryMetadata { bool is_input; std::unordered_map attrs; - auto boundary_key() const { return std::forward_as_tuple(name, id); } + auto boundary_key() const { return absl::StrCat(name, "__@@__", id); } auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); } @@ -116,10 +116,12 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { for (mlir::func::FuncOp& func_op : func_ops) { llvm::DenseMap op_order_map = BuildOpOrderMap(func_op); - for (auto op : func_op.getOps()) { - if (mlir::failed( - BuildStableHLOComposite(op.getOperation(), op_order_map))) { - op.emitError() << "failed to build composite."; + std::unordered_map> + boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); + + for (const auto& [unused, ops] : boundary_output_ops_map) { + if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + func_op.emitError() << "failed to build composite."; return signalPassFailure(); } } @@ -144,6 +146,31 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { return op_order_map; } + std::unordered_map> + BuildBoundaryOutputOpsMap(mlir::func::FuncOp func_op) { + std::unordered_map> + boundary_output_ops; + + for (auto op : func_op.getOps()) { + auto metadata_or = GetBoundaryMetadata(op); + if (mlir::failed(metadata_or)) { + continue; + } + + std::unique_ptr metadata = std::move(*metadata_or); + if (metadata == nullptr || metadata->is_input) { + continue; + } + + auto& output_ops = boundary_output_ops[metadata->boundary_key()]; + if (metadata->pos >= output_ops.size()) { + output_ops.resize(metadata->pos + 1, nullptr); + } + output_ops[metadata->pos] = op.getOperation(); + } + return boundary_output_ops; + } + mlir::FailureOr> GetBoundaryMetadata( mlir::Operation* op) { if (!IsXlaMarkTensorOp(op)) { @@ -196,19 +223,34 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } mlir::LogicalResult BuildStableHLOComposite( - mlir::Operation* op, + const llvm::SmallVector& output_ops, const llvm::DenseMap& op_order_map) { - auto metadata_or = GetBoundaryMetadata(op); + if (output_ops.empty()) { + return mlir::success(); + } + + // Get the output op with minimum order num as the representative. + mlir::Operation* first_output_op = output_ops[0]; + for (mlir::Operation* op : output_ops) { + if (op_order_map.at(op) < op_order_map.at(first_output_op)) { + first_output_op = op; + } + } + + auto metadata_or = GetBoundaryMetadata(first_output_op); if (mlir::failed(metadata_or)) { return mlir::failure(); } std::unique_ptr metadata = std::move(*metadata_or); if (metadata == nullptr || metadata->is_input) { - return mlir::success(); + // There should always be a valid boundary output metadata associated with + // each op in output_ops. + return mlir::failure(); } - auto args_ops_or = GetBoundaryArgsAndOps(op, *metadata, op_order_map); + auto args_ops_or = + GetBoundaryArgsAndOps(output_ops, *metadata, op_order_map); if (mlir::failed(args_ops_or)) { return mlir::failure(); } @@ -216,10 +258,9 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { auto [args, impl_ops] = *args_ops_or; mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc( - op, absl::StrCat(metadata->name, ".impl"), args, impl_ops); - + output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops); mlir::FailureOr composite_op_or = - BuildStableHLOCompositeOp(op, impl_func, args, *metadata); + BuildStableHLOCompositeOp(first_output_op, impl_func, args, *metadata); if (mlir::failed(composite_op_or)) { return mlir::failure(); } @@ -227,9 +268,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { // Updates all users of this op's result(s) to use the results(s) of impl // func call. - for (size_t i = 0; i < op->getNumResults(); ++i) { - mlir::OpResult result = op->getResult(i); - result.replaceAllUsesWith(composite_op->getResult(i)); + size_t composite_result_i = 0; + for (mlir::Operation* op : output_ops) { + for (size_t i = 0; i < op->getNumResults(); ++i) { + mlir::OpResult result = op->getResult(i); + result.replaceAllUsesWith( + composite_op->getResult(composite_result_i++)); + } } // The unused impl_ops will be eliminated with canonicalizer. @@ -239,11 +284,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { mlir::FailureOr, llvm::SmallVector>> GetBoundaryArgsAndOps( - mlir::Operation* boundary_output_op, const BoundaryMetadata& metadata, + const llvm::SmallVector boundary_output_ops, + const BoundaryMetadata& metadata, const llvm::DenseMap& op_order_map) { llvm::SetVector impl_ops_setvec; llvm::SetVector> arg_pos_setvec; - llvm::SmallVector processing({boundary_output_op}); + llvm::SmallVector processing(boundary_output_ops.begin(), + boundary_output_ops.end()); // Reverse graph traversal: from boundary output op to boundary input op, // global function arg, or stablehlo constant. @@ -318,8 +365,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { } mlir::func::FuncOp BuildStableHLOCompositeImplFunc( - mlir::Operation* boundary_output_op, llvm::StringRef func_name, - const llvm::SmallVector& args, + const llvm::SmallVector boundary_output_ops, + llvm::StringRef func_name, const llvm::SmallVector& args, const llvm::SmallVector& impl_ops) { mlir::ModuleOp module_op = getOperation(); mlir::MLIRContext* context = &getContext(); @@ -333,9 +380,11 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { arg_types.push_back(arg.getType()); arg_locs.push_back(arg.getLoc()); } - llvm::SmallVector result_types( - boundary_output_op->getResultTypes().begin(), - boundary_output_op->getResultTypes().end()); + llvm::SmallVector result_types; + for (mlir::Operation* op : boundary_output_ops) { + result_types.append(op->getResultTypes().begin(), + op->getResultTypes().end()); + } mlir::func::FuncOp impl_func = builder.create( module_op.getLoc(), func_name, @@ -350,9 +399,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { mlir::Operation* cloned_op = builder.clone(*original_op, mapping); mapping.map(original_op, cloned_op); } - builder.create( - impl_func.getBody().getLoc(), - mapping.lookup(boundary_output_op)->getResults()); + + llvm::SmallVector results; + for (mlir::Operation* op : boundary_output_ops) { + results.append(mapping.lookup(op)->getResults().begin(), + mapping.lookup(op)->getResults().end()); + } + builder.create(impl_func.getBody().getLoc(), results); // Adds the new function to symbol table. mlir::SymbolTable symbol_table(module_op); diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py index 2e0baa129f1..d67d2c6996b 100644 --- a/torch_xla/experimental/mark_pattern_utils.py +++ b/torch_xla/experimental/mark_pattern_utils.py @@ -80,8 +80,4 @@ def mark_outputs(self, *tensors: torch.Tensor): should be replaced by the marked tensors in later usages. """ - if len(tensors) > 1: - # TODO: Allow multiple composite outputs - raise ValueError( - f"StableHLO composite with more than one outputs is not supported.") return self._mark_tensor(*tensors, is_input=False)