Skip to content

Commit

Permalink
Update xla_mlir_debuginfo.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc committed Mar 21, 2024
1 parent 1e825bd commit 0150e28
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torch_xla/experimental/xla_mlir_debuginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,23 @@

import torch
import torch_xla
from torch_xla.core import xla_model as xm
from torch_xla.core.xla_model import XLA_LIB

# Enable debug info automatically when importing this file. This is necessary
# to propagate any debug info to downstream MLIR locations.
os.environ["XLA_HLO_DEBUG"] = "1"
xla_device = xm.xla_device()

XLA_LIB.define("write_mlir_debuginfo(Tensor x, str data) -> Tensor")


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo",
"CompositeExplicitAutograd")
def write_mlir_debuginfo(x, data: str):
if x.device != xla_device:
return x

begin_token = "<XLA_MLIR_DEBUGINFO_BEGIN>"
end_token = "<XLA_MLIR_DEBUGINFO_END>"
# Add the debuginfo string as the op prefix in MLIR location, surrounded
Expand Down

0 comments on commit 0150e28

Please sign in to comment.