From f8d35ccd6e852a5a038e62ac41af6da234fbc9c8 Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Mon, 25 Mar 2024 11:59:49 -0700 Subject: [PATCH] Use XLA dispatch key --- torch_xla/experimental/xla_mlir_debuginfo.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torch_xla/experimental/xla_mlir_debuginfo.py b/torch_xla/experimental/xla_mlir_debuginfo.py index 732cd88ba11..4d12e20b2b2 100644 --- a/torch_xla/experimental/xla_mlir_debuginfo.py +++ b/torch_xla/experimental/xla_mlir_debuginfo.py @@ -13,12 +13,8 @@ XLA_LIB.define("write_mlir_debuginfo(Tensor x, str data) -> Tensor") -@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", - "CompositeExplicitAutograd") +@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "XLA") def write_mlir_debuginfo(x, data: str): - if x.device != xla_device: - return x - begin_token = "" end_token = "" # Add the debuginfo string as the op prefix in MLIR location, surrounded @@ -33,6 +29,12 @@ def write_mlir_debuginfo(x, data: str): return x +@torch.library.impl(XLA_LIB, "write_mlir_debuginfo", + "CompositeExplicitAutograd") +def write_mlir_debuginfo(x, data: str): + return x + + @torch.library.impl(XLA_LIB, "write_mlir_debuginfo", "Meta") def write_mlir_debuginfo_meta(x, data: str): return torch.empty_like(x)