Support multiple StableHLO Composite outputs #6295

Merged: 6 commits, Feb 4, 2024
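For context, this PR lifts the single-output restriction on StableHLO composites: `mark_outputs` previously raised a `ValueError` for more than one tensor, and the C++ pass stitched only a single boundary output op into the generated composite call. Below is a minimal end-to-end sketch of the newly supported usage, adapted from the test added in this PR; the export calls (`torch.export.export`, `torch_xla.stablehlo.exported_program_to_stablehlo`, `get_stablehlo_text`) are assumptions about the surrounding torch_xla API, not part of this diff:

import torch
import torch_xla.stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class M(torch.nn.Module):

  def forward(self, x, y):
    builder = StableHLOCompositeBuilder("sample_composite")
    x, y = builder.mark_inputs(x, y)
    a = x + y
    b = x - y
    c = x + 1
    # Marking several outputs raised ValueError before this change.
    a, b, c = builder.mark_outputs(a, b, c)
    return a + b + c

exported = torch.export.export(M(), (torch.randn(5, 5), torch.randn(5, 5)))
shlo = torch_xla.stablehlo.exported_program_to_stablehlo(exported)
# The lowered module should contain a single "@stablehlo.composite" call
# whose impl function returns all three marked outputs.
print(shlo.get_stablehlo_text())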
31 changes: 19 additions & 12 deletions test/stablehlo/test_mark_pattern.py
@@ -54,9 +54,6 @@ def test_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
@@ -98,9 +95,6 @@ def test_composite_builder_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
@@ -161,9 +155,6 @@ def test_inlined_composite_builder_export_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
@@ -190,7 +181,24 @@ def forward(self, x, y):
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

def test_multiple_input(self):
def test_composite_builder_multiple_outputs(self):

class M(torch.nn.Module):

def forward(self, x, y):
builder = StableHLOCompositeBuilder("sample_composite")
x, y = builder.mark_inputs(x, y)
a = x + y
b = x - y
c = x + 1
a, b, c = builder.mark_outputs(a, b, c)
return a + b + c

input_args = (torch.randn((5, 5)), torch.randn((5, 5)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)

def test_multiple_inputs(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
@@ -205,8 +213,7 @@ def f(x, y):
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)

@unittest.skip("Multiple outputs patterns are not supported now.")
def test_multiple_output(self):
def test_multiple_outputs(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
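The body of `test_multiple_outputs` is truncated in this view. For illustration, a hypothetical sketch of the low-level marking pattern such a test exercises; the positional arguments to `mark_tensor` are (tensor, pattern name, position, pattern id, is_input), matching the calls shown above:

def f(x, y):
  x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
  y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
  out1, out2 = x + y, x - y
  # Two boundary outputs: positions 0 and 1 of the same pattern instance "0".
  out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
  out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
  return out1, out2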
103 changes: 78 additions & 25 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.cc
@@ -44,7 +44,7 @@ struct BoundaryMetadata {
bool is_input;
std::unordered_map<std::string, json> attrs;

auto boundary_key() const { return std::forward_as_tuple(name, id); }
auto boundary_key() const { return absl::StrCat(name, "__@@__", id); }
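// NOTE: Returning a flat string instead of a tuple lets boundary_key() be
// used directly as the std::unordered_map key when grouping boundary output
// ops below (std::tuple has no std::hash specialization).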

auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); }

@@ -116,10 +116,12 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
for (mlir::func::FuncOp& func_op : func_ops) {
llvm::DenseMap<const mlir::Operation*, size_t> op_order_map =
BuildOpOrderMap(func_op);
for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
if (mlir::failed(
BuildStableHLOComposite(op.getOperation(), op_order_map))) {
op.emitError() << "failed to build composite.";
std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op);

for (const auto& [unused, ops] : boundary_output_ops_map) {
if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) {
func_op.emitError() << "failed to build composite.";
return signalPassFailure();
}
}
@@ -144,6 +146,31 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
return op_order_map;
}

std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
BuildBoundaryOutputOpsMap(mlir::func::FuncOp func_op) {
std::unordered_map<std::string, llvm::SmallVector<mlir::Operation*>>
boundary_output_ops;
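// Maps boundary_key() -> the boundary's output ops, indexed by their `pos`
// attribute; positions not yet seen hold nullptr placeholders.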

for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
auto metadata_or = GetBoundaryMetadata(op);
if (mlir::failed(metadata_or)) {
continue;
}

std::unique_ptr<BoundaryMetadata> metadata = std::move(*metadata_or);
if (metadata == nullptr || metadata->is_input) {
continue;
}

auto& output_ops = boundary_output_ops[metadata->boundary_key()];
if (metadata->pos >= output_ops.size()) {
output_ops.resize(metadata->pos + 1, nullptr);
}
output_ops[metadata->pos] = op.getOperation();
}
return boundary_output_ops;
}

mlir::FailureOr<std::unique_ptr<BoundaryMetadata>> GetBoundaryMetadata(
mlir::Operation* op) {
if (!IsXlaMarkTensorOp(op)) {
@@ -196,40 +223,58 @@ }
}

mlir::LogicalResult BuildStableHLOComposite(
mlir::Operation* op,
const llvm::SmallVector<mlir::Operation*>& output_ops,
const llvm::DenseMap<const mlir::Operation*, size_t>& op_order_map) {
auto metadata_or = GetBoundaryMetadata(op);
if (output_ops.empty()) {
return mlir::success();
}

// Use the output op with the minimum order number as the representative.
mlir::Operation* first_output_op = output_ops[0];
for (mlir::Operation* op : output_ops) {
if (op_order_map.at(op) < op_order_map.at(first_output_op)) {
first_output_op = op;
}
}

auto metadata_or = GetBoundaryMetadata(first_output_op);
if (mlir::failed(metadata_or)) {
return mlir::failure();
}

std::unique_ptr<BoundaryMetadata> metadata = std::move(*metadata_or);
if (metadata == nullptr || metadata->is_input) {
return mlir::success();
// There should always be a valid boundary output metadata associated with
// each op in output_ops.
return mlir::failure();
}

auto args_ops_or = GetBoundaryArgsAndOps(op, *metadata, op_order_map);
auto args_ops_or =
GetBoundaryArgsAndOps(output_ops, *metadata, op_order_map);
if (mlir::failed(args_ops_or)) {
return mlir::failure();
}

auto [args, impl_ops] = *args_ops_or;

mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc(
op, absl::StrCat(metadata->name, ".impl"), args, impl_ops);

output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops);
mlir::FailureOr<mlir::Operation*> composite_op_or =
BuildStableHLOCompositeOp(op, impl_func, args, *metadata);
BuildStableHLOCompositeOp(first_output_op, impl_func, args, *metadata);
if (mlir::failed(composite_op_or)) {
return mlir::failure();
}
mlir::Operation* composite_op = *composite_op_or;

// Updates all users of the output ops' result(s) to use the corresponding
// result(s) of the impl func call.
for (size_t i = 0; i < op->getNumResults(); ++i) {
mlir::OpResult result = op->getResult(i);
result.replaceAllUsesWith(composite_op->getResult(i));
size_t composite_result_i = 0;
for (mlir::Operation* op : output_ops) {
for (size_t i = 0; i < op->getNumResults(); ++i) {
mlir::OpResult result = op->getResult(i);
result.replaceAllUsesWith(
composite_op->getResult(composite_result_i++));
}
}

// The unused impl_ops will be eliminated by the canonicalizer.
Expand All @@ -239,11 +284,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
mlir::FailureOr<std::pair<llvm::SmallVector<mlir::Value>,
llvm::SmallVector<mlir::Operation*>>>
GetBoundaryArgsAndOps(
mlir::Operation* boundary_output_op, const BoundaryMetadata& metadata,
const llvm::SmallVector<mlir::Operation*>& boundary_output_ops,
const BoundaryMetadata& metadata,
const llvm::DenseMap<const mlir::Operation*, size_t>& op_order_map) {
llvm::SetVector<mlir::Operation*> impl_ops_setvec;
llvm::SetVector<std::pair<mlir::Value, int64_t>> arg_pos_setvec;
llvm::SmallVector<mlir::Operation*> processing({boundary_output_op});
llvm::SmallVector<mlir::Operation*> processing(boundary_output_ops.begin(),
boundary_output_ops.end());

// Reverse graph traversal: from boundary output op to boundary input op,
// global function arg, or stablehlo constant.
Expand Down Expand Up @@ -318,8 +365,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
}

mlir::func::FuncOp BuildStableHLOCompositeImplFunc(
mlir::Operation* boundary_output_op, llvm::StringRef func_name,
const llvm::SmallVector<mlir::Value>& args,
const llvm::SmallVector<mlir::Operation*>& boundary_output_ops,
llvm::StringRef func_name, const llvm::SmallVector<mlir::Value>& args,
const llvm::SmallVector<mlir::Operation*>& impl_ops) {
mlir::ModuleOp module_op = getOperation();
mlir::MLIRContext* context = &getContext();
Expand All @@ -333,9 +380,11 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
arg_types.push_back(arg.getType());
arg_locs.push_back(arg.getLoc());
}
llvm::SmallVector<mlir::Type> result_types(
boundary_output_op->getResultTypes().begin(),
boundary_output_op->getResultTypes().end());
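// The impl function returns the concatenated results of all boundary
// output ops, in output_ops order.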
llvm::SmallVector<mlir::Type> result_types;
for (mlir::Operation* op : boundary_output_ops) {
result_types.append(op->getResultTypes().begin(),
op->getResultTypes().end());
}

mlir::func::FuncOp impl_func = builder.create<mlir::func::FuncOp>(
module_op.getLoc(), func_name,
Expand All @@ -350,9 +399,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
mlir::Operation* cloned_op = builder.clone(*original_op, mapping);
mapping.map(original_op, cloned_op);
}
builder.create<mlir::func::ReturnOp>(
impl_func.getBody().getLoc(),
mapping.lookup(boundary_output_op)->getResults());

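// Gather the cloned counterparts of every boundary output op's results to
// form the impl function's return values.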
llvm::SmallVector<mlir::Value> results;
for (mlir::Operation* op : boundary_output_ops) {
results.append(mapping.lookup(op)->getResults().begin(),
mapping.lookup(op)->getResults().end());
}
builder.create<mlir::func::ReturnOp>(impl_func.getBody().getLoc(), results);

// Adds the new function to symbol table.
mlir::SymbolTable symbol_table(module_op);
Expand Down
4 changes: 0 additions & 4 deletions torch_xla/experimental/mark_pattern_utils.py
@@ -80,8 +80,4 @@ def mark_outputs(self, *tensors: torch.Tensor):
should be replaced by the marked tensors in later usages.
"""

if len(tensors) > 1:
# TODO: Allow multiple composite outputs
raise ValueError(
f"StableHLO composite with more than one outputs is not supported.")
return self._mark_tensor(*tensors, is_input=False)
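With the guard removed, single- and multi-tensor marking go through the same `_mark_tensor` path. A hypothetical call pattern (the return arity mirroring the number of tensors passed is an assumption based on the tests in this PR):

out = builder.mark_outputs(out)          # single output, as before
a, b, c = builder.mark_outputs(a, b, c)  # multiple outputs, now supported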