Skip to content

Commit

Permalink
Support multiple StableHLO Composite outputs (#6295)
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored Feb 4, 2024
1 parent 8fc8d57 commit 535d398
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 41 deletions.
31 changes: 19 additions & 12 deletions test/stablehlo/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def test_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="0", is_input=True)
Expand Down Expand Up @@ -92,9 +89,6 @@ def test_composite_builder_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
Expand Down Expand Up @@ -155,9 +149,6 @@ def test_inlined_composite_builder_export_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
Expand All @@ -184,7 +175,24 @@ def forward(self, x, y):
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

def test_multiple_input(self):
def test_composite_builder_multiple_outputs(self):

class M(torch.nn.Module):

def forward(self, x, y):
builder = StableHLOCompositeBuilder("sample_composite")
x, y = builder.mark_inputs(x, y)
a = x + y
b = x - y
c = x + 1
a, b, c = builder.mark_outputs(a, b, c)
return a + b + c

input_args = (torch.randn((5, 5)), torch.randn((5, 5)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)

def test_multiple_inputs(self):

def f(x, y):
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
Expand All @@ -199,8 +207,7 @@ def f(x, y):
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)

@unittest.skip("Multiple outputs patterns are not supported now.")
def test_multiple_output(self):
def test_multiple_outputs(self):

def f(x, y):
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
Expand Down
103 changes: 78 additions & 25 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct BoundaryMetadata {
bool is_input;
std::unordered_map<std::string, json> attrs;

auto boundary_key() const { return std::forward_as_tuple(name, id); }
auto boundary_key() const { return absl::StrCat(name, "__@@__", id); }

auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); }

Expand Down Expand Up @@ -116,10 +116,12 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
for (mlir::func::FuncOp& func_op : func_ops) {
llvm::DenseMap<const mlir::Operation*, size_t> op_order_map =
BuildOpOrderMap(func_op);
for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
if (mlir::failed(
BuildStableHLOComposite(op.getOperation(), op_order_map))) {
op.emitError() << "failed to build composite.";
std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op);

for (const auto& [unused, ops] : boundary_output_ops_map) {
if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) {
func_op.emitError() << "failed to build composite.";
return signalPassFailure();
}
}
Expand All @@ -144,6 +146,31 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
return op_order_map;
}

std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
BuildBoundaryOutputOpsMap(mlir::func::FuncOp func_op) {
std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
boundary_output_ops;

for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
auto metadata_or = GetBoundaryMetadata(op);
if (mlir::failed(metadata_or)) {
continue;
}

std::unique_ptr<BoundaryMetadata> metadata = std::move(*metadata_or);
if (metadata == nullptr || metadata->is_input) {
continue;
}

auto& output_ops = boundary_output_ops[metadata->boundary_key()];
if (metadata->pos >= output_ops.size()) {
output_ops.resize(metadata->pos + 1, nullptr);
}
output_ops[metadata->pos] = op.getOperation();
}
return boundary_output_ops;
}

mlir::FailureOr<std::unique_ptr<BoundaryMetadata>> GetBoundaryMetadata(
mlir::Operation* op) {
if (!IsXlaMarkTensorOp(op)) {
Expand Down Expand Up @@ -196,40 +223,58 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
}

mlir::LogicalResult BuildStableHLOComposite(
mlir::Operation* op,
const llvm::SmallVector<mlir::Operation*>& output_ops,
const llvm::DenseMap<const mlir::Operation*, size_t>& op_order_map) {
auto metadata_or = GetBoundaryMetadata(op);
if (output_ops.empty()) {
return mlir::success();
}

// Get the output op with minimum order num as the representative.
mlir::Operation* first_output_op = output_ops[0];
for (mlir::Operation* op : output_ops) {
if (op_order_map.at(op) < op_order_map.at(first_output_op)) {
first_output_op = op;
}
}

auto metadata_or = GetBoundaryMetadata(first_output_op);
if (mlir::failed(metadata_or)) {
return mlir::failure();
}

std::unique_ptr<BoundaryMetadata> metadata = std::move(*metadata_or);
if (metadata == nullptr || metadata->is_input) {
return mlir::success();
// There should always be a valid boundary output metadata associated with
// each op in output_ops.
return mlir::failure();
}

auto args_ops_or = GetBoundaryArgsAndOps(op, *metadata, op_order_map);
auto args_ops_or =
GetBoundaryArgsAndOps(output_ops, *metadata, op_order_map);
if (mlir::failed(args_ops_or)) {
return mlir::failure();
}

auto [args, impl_ops] = *args_ops_or;

mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc(
op, absl::StrCat(metadata->name, ".impl"), args, impl_ops);

output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops);
mlir::FailureOr<mlir::Operation*> composite_op_or =
BuildStableHLOCompositeOp(op, impl_func, args, *metadata);
BuildStableHLOCompositeOp(first_output_op, impl_func, args, *metadata);
if (mlir::failed(composite_op_or)) {
return mlir::failure();
}
mlir::Operation* composite_op = *composite_op_or;

// Updates all users of this op's result(s) to use the results(s) of impl
// func call.
for (size_t i = 0; i < op->getNumResults(); ++i) {
mlir::OpResult result = op->getResult(i);
result.replaceAllUsesWith(composite_op->getResult(i));
size_t composite_result_i = 0;
for (mlir::Operation* op : output_ops) {
for (size_t i = 0; i < op->getNumResults(); ++i) {
mlir::OpResult result = op->getResult(i);
result.replaceAllUsesWith(
composite_op->getResult(composite_result_i++));
}
}

// The unused impl_ops will be eliminated with canonicalizer.
Expand All @@ -239,11 +284,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
mlir::FailureOr<std::pair<llvm::SmallVector<mlir::Value>,
llvm::SmallVector<mlir::Operation*>>>
GetBoundaryArgsAndOps(
mlir::Operation* boundary_output_op, const BoundaryMetadata& metadata,
const llvm::SmallVector<mlir::Operation*> boundary_output_ops,
const BoundaryMetadata& metadata,
const llvm::DenseMap<const mlir::Operation*, size_t>& op_order_map) {
llvm::SetVector<mlir::Operation*> impl_ops_setvec;
llvm::SetVector<std::pair<mlir::Value, int64_t>> arg_pos_setvec;
llvm::SmallVector<mlir::Operation*> processing({boundary_output_op});
llvm::SmallVector<mlir::Operation*> processing(boundary_output_ops.begin(),
boundary_output_ops.end());

// Reverse graph traversal: from boundary output op to boundary input op,
// global function arg, or stablehlo constant.
Expand Down Expand Up @@ -318,8 +365,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
}

mlir::func::FuncOp BuildStableHLOCompositeImplFunc(
mlir::Operation* boundary_output_op, llvm::StringRef func_name,
const llvm::SmallVector<mlir::Value>& args,
const llvm::SmallVector<mlir::Operation*> boundary_output_ops,
llvm::StringRef func_name, const llvm::SmallVector<mlir::Value>& args,
const llvm::SmallVector<mlir::Operation*>& impl_ops) {
mlir::ModuleOp module_op = getOperation();
mlir::MLIRContext* context = &getContext();
Expand All @@ -333,9 +380,11 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
arg_types.push_back(arg.getType());
arg_locs.push_back(arg.getLoc());
}
llvm::SmallVector<mlir::Type> result_types(
boundary_output_op->getResultTypes().begin(),
boundary_output_op->getResultTypes().end());
llvm::SmallVector<mlir::Type> result_types;
for (mlir::Operation* op : boundary_output_ops) {
result_types.append(op->getResultTypes().begin(),
op->getResultTypes().end());
}

mlir::func::FuncOp impl_func = builder.create<mlir::func::FuncOp>(
module_op.getLoc(), func_name,
Expand All @@ -350,9 +399,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
mlir::Operation* cloned_op = builder.clone(*original_op, mapping);
mapping.map(original_op, cloned_op);
}
builder.create<mlir::func::ReturnOp>(
impl_func.getBody().getLoc(),
mapping.lookup(boundary_output_op)->getResults());

llvm::SmallVector<mlir::Value> results;
for (mlir::Operation* op : boundary_output_ops) {
results.append(mapping.lookup(op)->getResults().begin(),
mapping.lookup(op)->getResults().end());
}
builder.create<mlir::func::ReturnOp>(impl_func.getBody().getLoc(), results);

// Adds the new function to symbol table.
mlir::SymbolTable symbol_table(module_op);
Expand Down
4 changes: 0 additions & 4 deletions torch_xla/experimental/mark_pattern_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,4 @@ def mark_outputs(self, *tensors: torch.Tensor):
should be replaced by the marked tensors in later usages.
"""

if len(tensors) > 1:
# TODO: Allow multiple composite outputs
raise ValueError(
f"StableHLO composite with more than one outputs is not supported.")
return self._mark_tensor(*tensors, is_input=False)

0 comments on commit 535d398

Please sign in to comment.