Skip to content

Commit

Permalink
Use XLA dispatch key
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc committed Mar 25, 2024
1 parent 730c5ec commit f8d35cc
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions torch_xla/experimental/xla_mlir_debuginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
XLA_LIB.define("write_mlir_debuginfo(Tensor x, str data) -> Tensor")


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo",
"CompositeExplicitAutograd")
@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "XLA")
def write_mlir_debuginfo(x, data: str):
if x.device != xla_device:
return x

begin_token = "<XLA_MLIR_DEBUGINFO_BEGIN>"
end_token = "<XLA_MLIR_DEBUGINFO_END>"
# Add the debuginfo string as the op prefix in MLIR location, surrounded
Expand All @@ -33,6 +29,12 @@ def write_mlir_debuginfo(x, data: str):
return x


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo",
"CompositeExplicitAutograd")
def write_mlir_debuginfo(x, data: str):
return x


@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "Meta")
def write_mlir_debuginfo_meta(x, data: str):
return torch.empty_like(x)

0 comments on commit f8d35cc

Please sign in to comment.