-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add experimental MLIR debuginfo writer API (#6799)
- Loading branch information
Showing
6 changed files
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import unittest | ||
import re | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.experimental.xla_mlir_debuginfo | ||
from torch_xla.stablehlo import exported_program_to_stablehlo | ||
|
||
|
||
class XlaMlirDebuginfoTest(unittest.TestCase): | ||
|
||
def test_write_debuginfo(self): | ||
|
||
class SampleModel(torch.nn.Module): | ||
|
||
def forward(self, x, y): | ||
x = x + y | ||
x = torch.ops.xla.write_mlir_debuginfo(x, "MY_ADD") | ||
x = x - y | ||
x = torch.ops.xla.write_mlir_debuginfo(x, "MY_SUB") | ||
return x | ||
|
||
model = SampleModel() | ||
exported_program = torch.export.export(model, | ||
(torch.rand(10), torch.rand(10))) | ||
mlir_text = exported_program_to_stablehlo( | ||
exported_program).get_stablehlo_text() | ||
self.assertTrue(re.search(r'stablehlo.add.+\"MY_ADD\"', mlir_text)) | ||
self.assertTrue(re.search(r'stablehlo.sub.+\"MY_SUB\"', mlir_text)) | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#include "torch_xla/csrc/runtime/xla_mlir_debuginfo_helper.h" | ||
|
||
#include <cstring> | ||
#include <string> | ||
#include <utility> | ||
|
||
#include "absl/log/log.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
namespace { | ||
|
||
// Defined in torch_xla/experimental/xla_mlir_debuginfo.py | ||
static constexpr char XLA_MLIR_DEBUGINFO_BEGIN[] = "<XLA_MLIR_DEBUGINFO_BEGIN>"; | ||
static constexpr char XLA_MLIR_DEBUGINFO_END[] = "<XLA_MLIR_DEBUGINFO_END>"; | ||
|
||
class PrepareXlaMlirDebuginfoPass : public mlir::OperationPass<mlir::ModuleOp> { | ||
public: | ||
explicit PrepareXlaMlirDebuginfoPass() | ||
: mlir::OperationPass<mlir::ModuleOp>::OperationPass( | ||
mlir::TypeID::get<PrepareXlaMlirDebuginfoPass>()) {} | ||
|
||
~PrepareXlaMlirDebuginfoPass() override = default; | ||
|
||
void runOnOperation() override { | ||
mlir::MLIRContext* context = &getContext(); | ||
getOperation().walk([&](mlir::Operation* op) { | ||
llvm::SmallVector<std::string> debuginfos; | ||
ExtractXlaMlirDebuginfo(op->getLoc(), debuginfos); | ||
|
||
if (!debuginfos.empty()) { | ||
// If multiple debuginfos are found (which should be an exception), | ||
// pick arbitrary one and discard the rest; | ||
const std::string& debuginfo = debuginfos[0]; | ||
op->setLoc( | ||
mlir::NameLoc::get(mlir::StringAttr::get(context, debuginfo))); | ||
} | ||
// TODO: Remove unspecified locations when a global flag is set. | ||
}); | ||
} | ||
|
||
void ExtractXlaMlirDebuginfo(mlir::Location loc, | ||
llvm::SmallVector<std::string>& debuginfos) { | ||
if (loc.isa<mlir::FusedLoc>()) { | ||
for (mlir::Location subloc : | ||
loc.dyn_cast<mlir::FusedLoc>().getLocations()) { | ||
ExtractXlaMlirDebuginfo(subloc, debuginfos); | ||
} | ||
} | ||
if (loc.isa<mlir::NameLoc>()) { | ||
std::string name(loc.dyn_cast<mlir::NameLoc>().getName().str()); | ||
|
||
for (size_t i = 0; i < name.size();) { | ||
size_t begin = name.find(XLA_MLIR_DEBUGINFO_BEGIN, i); | ||
if (begin == std::string::npos) { | ||
break; | ||
} | ||
begin += strlen(XLA_MLIR_DEBUGINFO_BEGIN); | ||
size_t end = name.find(XLA_MLIR_DEBUGINFO_END, begin); | ||
if (end == std::string::npos) { | ||
break; | ||
} | ||
|
||
std::string debuginfo = name.substr(begin, end - begin); | ||
debuginfos.push_back(std::move(debuginfo)); | ||
|
||
i = end + strlen(XLA_MLIR_DEBUGINFO_BEGIN); | ||
} | ||
} | ||
// TODO: Handle other loc types if debuginfo can be propagated/nested in | ||
// other loc type. | ||
} | ||
|
||
mlir::StringRef getName() const override { | ||
return llvm::getTypeName<PrepareXlaMlirDebuginfoPass>(); | ||
} | ||
|
||
std::unique_ptr<mlir::Pass> clonePass() const override { | ||
return std::make_unique<PrepareXlaMlirDebuginfoPass>(*this); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> | ||
CreatePrepareXlaMlirDebuginfoPass() { | ||
return std::make_unique<PrepareXlaMlirDebuginfoPass>(); | ||
} | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef XLA_MLIR_DEBUGINFO_HELPER_H_ | ||
#define XLA_MLIR_DEBUGINFO_HELPER_H_ | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace torch_xla { | ||
namespace runtime { | ||
|
||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> | ||
CreatePrepareXlaMlirDebuginfoPass(); | ||
|
||
} // namespace runtime | ||
} // namespace torch_xla | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
|
||
import torch | ||
import torch_xla | ||
from torch_xla.core import xla_model as xm | ||
from torch_xla.core.xla_model import XLA_LIB | ||
|
||
# Enable debug info automatically when importing this file. This is necessary | ||
# to propagate any debug info to downstream MLIR locations. | ||
os.environ["XLA_HLO_DEBUG"] = "1" | ||
xla_device = xm.xla_device() | ||
|
||
XLA_LIB.define("write_mlir_debuginfo(Tensor x, str data) -> Tensor") | ||
|
||
|
||
@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "XLA") | ||
def write_mlir_debuginfo(x, data: str): | ||
begin_token = "<XLA_MLIR_DEBUGINFO_BEGIN>" | ||
end_token = "<XLA_MLIR_DEBUGINFO_END>" | ||
# Add the debuginfo string as the op prefix in MLIR location, surrounded | ||
# by begin and end tokens. The tokens and suffix op name will be removed | ||
# in the downstream pass PrepareXlaMlirDebuginfoPass after converting | ||
# HLO proto to MLIR. | ||
torch_xla._XLAC._set_xla_custom_op_name_prefix( | ||
x, | ||
begin_token + data + end_token, | ||
0, | ||
) | ||
return x | ||
|
||
|
||
@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", | ||
"CompositeExplicitAutograd") | ||
def write_mlir_debuginfo(x, data: str): | ||
return x | ||
|
||
|
||
@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "Meta") | ||
def write_mlir_debuginfo_meta(x, data: str): | ||
return torch.empty_like(x) |