Add experimental MLIR debuginfo writer API (#6799)
chunnienc authored Mar 26, 2024
1 parent 1ad6bb4 commit 64a146d
Showing 6 changed files with 197 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/stablehlo/test_mlir_debuginfo.py
@@ -0,0 +1,34 @@
import re
import sys
import unittest

import torch
import torch_xla
import torch_xla.experimental.xla_mlir_debuginfo
from torch_xla.stablehlo import exported_program_to_stablehlo


class XlaMlirDebuginfoTest(unittest.TestCase):

def test_write_debuginfo(self):

class SampleModel(torch.nn.Module):

def forward(self, x, y):
x = x + y
x = torch.ops.xla.write_mlir_debuginfo(x, "MY_ADD")
x = x - y
x = torch.ops.xla.write_mlir_debuginfo(x, "MY_SUB")
return x

model = SampleModel()
exported_program = torch.export.export(model,
(torch.rand(10), torch.rand(10)))
mlir_text = exported_program_to_stablehlo(
exported_program).get_stablehlo_text()
self.assertTrue(re.search(r'stablehlo.add.+\"MY_ADD\"', mlir_text))
self.assertTrue(re.search(r'stablehlo.sub.+\"MY_SUB\"', mlir_text))


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
12 changes: 12 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -325,13 +325,25 @@ cc_library(
],
)

cc_library(
name = "xla_mlir_debuginfo_helper",
srcs = ["xla_mlir_debuginfo_helper.cc"],
hdrs = ["xla_mlir_debuginfo_helper.h"],
deps = [
":types",
":xla_util",
"@xla//xla/mlir_hlo:all_passes",
],
)

cc_library(
name = "stablehlo_helper",
srcs = ["stablehlo_helper.cc"],
hdrs = ["stablehlo_helper.h"],
deps = [
":types",
":xla_util",
":xla_mlir_debuginfo_helper",
":stablehlo_composite_helper",
"@stablehlo//:stablehlo_portable_api",
"@stablehlo//:stablehlo_serialization",
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/stablehlo_helper.cc
@@ -13,6 +13,7 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/stablehlo_composite_helper.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
@@ -68,6 +69,7 @@ static absl::Status ConvertHloToMhlo(const xla::HloModuleProto* proto,
static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module,
mlir::MLIRContext* context) {
mlir::PassManager pm(context);
pm.addPass(torch_xla::runtime::CreatePrepareXlaMlirDebuginfoPass());
  // Legalize `mhlo.dot` to `mhlo.dot_general` to work around the shape
  // refinement issue in `stablehlo.dot`.
  // TODO(lsy323): Remove this pass when mhlo.dot can be legalized to
92 changes: 92 additions & 0 deletions torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.cc
@@ -0,0 +1,92 @@
#include "torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.h"

#include <cstring>
#include <string>
#include <utility>

#include "absl/log/log.h"

namespace torch_xla {
namespace runtime {

namespace {

// Defined in torch_xla/experimental/xla_mlir_debuginfo.py
static constexpr char XLA_MLIR_DEBUGINFO_BEGIN[] = "<XLA_MLIR_DEBUGINFO_BEGIN>";
static constexpr char XLA_MLIR_DEBUGINFO_END[] = "<XLA_MLIR_DEBUGINFO_END>";

class PrepareXlaMlirDebuginfoPass : public mlir::OperationPass<mlir::ModuleOp> {
public:
explicit PrepareXlaMlirDebuginfoPass()
: mlir::OperationPass<mlir::ModuleOp>::OperationPass(
mlir::TypeID::get<PrepareXlaMlirDebuginfoPass>()) {}

~PrepareXlaMlirDebuginfoPass() override = default;

void runOnOperation() override {
mlir::MLIRContext* context = &getContext();
getOperation().walk([&](mlir::Operation* op) {
llvm::SmallVector<std::string> debuginfos;
ExtractXlaMlirDebuginfo(op->getLoc(), debuginfos);

if (!debuginfos.empty()) {
        // If multiple debuginfos are found (which is unexpected), pick an
        // arbitrary one and discard the rest.
        const std::string& debuginfo = debuginfos[0];
op->setLoc(
mlir::NameLoc::get(mlir::StringAttr::get(context, debuginfo)));
}
// TODO: Remove unspecified locations when a global flag is set.
});
}

void ExtractXlaMlirDebuginfo(mlir::Location loc,
llvm::SmallVector<std::string>& debuginfos) {
if (loc.isa<mlir::FusedLoc>()) {
for (mlir::Location subloc :
loc.dyn_cast<mlir::FusedLoc>().getLocations()) {
ExtractXlaMlirDebuginfo(subloc, debuginfos);
}
}
if (loc.isa<mlir::NameLoc>()) {
std::string name(loc.dyn_cast<mlir::NameLoc>().getName().str());

for (size_t i = 0; i < name.size();) {
size_t begin = name.find(XLA_MLIR_DEBUGINFO_BEGIN, i);
if (begin == std::string::npos) {
break;
}
begin += strlen(XLA_MLIR_DEBUGINFO_BEGIN);
size_t end = name.find(XLA_MLIR_DEBUGINFO_END, begin);
if (end == std::string::npos) {
break;
}

std::string debuginfo = name.substr(begin, end - begin);
debuginfos.push_back(std::move(debuginfo));

        i = end + strlen(XLA_MLIR_DEBUGINFO_END);
}
}
    // TODO: Handle other loc types if debuginfo can be propagated/nested in
    // other loc types.
}

mlir::StringRef getName() const override {
return llvm::getTypeName<PrepareXlaMlirDebuginfoPass>();
}

std::unique_ptr<mlir::Pass> clonePass() const override {
return std::make_unique<PrepareXlaMlirDebuginfoPass>(*this);
}
};

} // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreatePrepareXlaMlirDebuginfoPass() {
return std::make_unique<PrepareXlaMlirDebuginfoPass>();
}

} // namespace runtime
} // namespace torch_xla
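
For reference, the begin/end token format that PrepareXlaMlirDebuginfoPass strips can be sketched in a few lines of Python. This is an illustration only, not part of the commit; extract_debuginfo is a hypothetical helper mirroring the extraction loop above.

# Illustration only (not part of this commit): a Python rendering of the
# token-extraction loop in PrepareXlaMlirDebuginfoPass.
XLA_MLIR_DEBUGINFO_BEGIN = "<XLA_MLIR_DEBUGINFO_BEGIN>"
XLA_MLIR_DEBUGINFO_END = "<XLA_MLIR_DEBUGINFO_END>"


def extract_debuginfo(name: str) -> list[str]:
  """Returns every substring wrapped in the begin/end debuginfo tokens."""
  results = []
  i = 0
  while i < len(name):
    begin = name.find(XLA_MLIR_DEBUGINFO_BEGIN, i)
    if begin == -1:
      break
    begin += len(XLA_MLIR_DEBUGINFO_BEGIN)
    end = name.find(XLA_MLIR_DEBUGINFO_END, begin)
    if end == -1:
      break
    results.append(name[begin:end])
    # Continue the search after the end token.
    i = end + len(XLA_MLIR_DEBUGINFO_END)
  return results


tagged = "<XLA_MLIR_DEBUGINFO_BEGIN>MY_ADD<XLA_MLIR_DEBUGINFO_END>.aten_add"
assert extract_debuginfo(tagged) == ["MY_ADD"]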
17 changes: 17 additions & 0 deletions torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.h
@@ -0,0 +1,17 @@
#ifndef XLA_MLIR_DEBUGINFO_HELPER_H_
#define XLA_MLIR_DEBUGINFO_HELPER_H_

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace torch_xla {
namespace runtime {

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreatePrepareXlaMlirDebuginfoPass();

} // namespace runtime
} // namespace torch_xla

#endif  // XLA_MLIR_DEBUGINFO_HELPER_H_
40 changes: 40 additions & 0 deletions torch_xla/experimental/xla_mlir_debuginfo.py
@@ -0,0 +1,40 @@
import os

import torch
import torch_xla
from torch_xla.core import xla_model as xm
from torch_xla.core.xla_model import XLA_LIB

# Enable debug info automatically when importing this file. This is necessary
# to propagate any debug info to downstream MLIR locations.
os.environ["XLA_HLO_DEBUG"] = "1"
xla_device = xm.xla_device()

XLA_LIB.define("write_mlir_debuginfo(Tensor x, str data) -> Tensor")


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "XLA")
def write_mlir_debuginfo(x, data: str):
begin_token = "<XLA_MLIR_DEBUGINFO_BEGIN>"
end_token = "<XLA_MLIR_DEBUGINFO_END>"
# Add the debuginfo string as the op prefix in MLIR location, surrounded
# by begin and end tokens. The tokens and suffix op name will be removed
# in the downstream pass PrepareXlaMlirDebuginfoPass after converting
# HLO proto to MLIR.
torch_xla._XLAC._set_xla_custom_op_name_prefix(
x,
begin_token + data + end_token,
0,
)
return x


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo",
                    "CompositeExplicitAutograd")
def write_mlir_debuginfo_composite(x, data: str):
return x


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "Meta")
def write_mlir_debuginfo_meta(x, data: str):
return torch.empty_like(x)
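
A minimal usage sketch of the new op from user code, assuming a torch_xla build that includes this commit; TaggedLinear and the "my_linear" tag are made-up names for illustration. The printed StableHLO should carry the tag in the tagged op's location, which is what the test above checks via regex.

# Sketch only: tag a submodule's output so the op producing it carries a
# readable MLIR location. Assumes an XLA-enabled environment.
import torch
import torch_xla.experimental.xla_mlir_debuginfo  # noqa: F401, registers the op
from torch_xla.stablehlo import exported_program_to_stablehlo


class TaggedLinear(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(8, 8)

  def forward(self, x):
    y = self.linear(x)
    # Attach "my_linear" to the MLIR location of the op that produces `y`.
    return torch.ops.xla.write_mlir_debuginfo(y, "my_linear")


ep = torch.export.export(TaggedLinear(), (torch.rand(4, 8),))
print(exported_program_to_stablehlo(ep).get_stablehlo_text())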
