use write mlir info api
fix flag

optimize

fix dynamic
Siyuan Liu committed May 13, 2024
1 parent 33aa818 commit 7a0cea0
Showing 3 changed files with 7 additions and 11 deletions.
test/stablehlo/test_exports.py: 1 change (0 additions, 1 deletion)
@@ -153,7 +153,6 @@ def forward(self, x):
     export_options.export_node_metadata = True
     shlo = exported_program_to_stablehlo(ep, options=export_options)
     shlo_text = shlo.get_stablehlo_text()
-    print(shlo_text)
     self.assertTrue('stack_trace' in shlo_text)
     self.assertTrue('nn_module_stack' in shlo_text)
     self.assertTrue('source_fn_stack' in shlo_text)
torch_xla/experimental/xla_mlir_debuginfo.py: 2 changes (1 addition, 1 deletion)
@@ -31,7 +31,7 @@ def write_mlir_debuginfo(x, data: str):

 @torch.library.impl(XLA_LIB, "write_mlir_debuginfo",
                     "CompositeExplicitAutograd")
-def write_mlir_debuginfo(x, data: str):
+def write_mlir_debuginfo_tensor(x, data: str):
   return x


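The rename above appears to keep this no-op CompositeExplicitAutograd registration from shadowing the earlier write_mlir_debuginfo definition shown in the hunk header, which is the name torch_xla/stablehlo.py imports below. A minimal sketch of calling that wrapper, assuming it keeps the (tensor, string) signature shown in the diff and that an XLA device is available:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.xla_mlir_debuginfo import write_mlir_debuginfo

x = torch.ones(2, 2, device=xm.xla_device())
# Returns the tensor unchanged; the attached string is meant to surface as
# MLIR/StableHLO debug info on the ops produced from x when the graph is lowered.
x = write_mlir_debuginfo(x, "decoder/layer0/forward")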
torch_xla/stablehlo.py: 15 changes (6 additions, 9 deletions)
@@ -22,6 +22,7 @@
 from torch_xla.experimental.unbounded_dynamism_export import (
     exported_program_has_symbolic_input_shape,
     process_exported_program_with_symbolic_input)
+from torch_xla.experimental.xla_mlir_debuginfo import write_mlir_debuginfo
 
 
 def _get_numpy_dtype(dtype):
@@ -271,7 +272,7 @@ def run_node(self, n) -> Any:
       dynamic_dims = [
           i for i, x in enumerate(fake_t.shape) if not isinstance(x, int)
       ]
-      return res
+      self._mark_dynamic(res, dynamic_dims)
     elif n.op == 'call_function':
       if hasattr(n.target, 'namespace'
                 ) and n.target.namespace in self.custom_ops_allowed_in_graph:
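With this change, placeholder results are marked with their symbolic (dynamic) dimension indices via the interpreter's _mark_dynamic helper instead of being returned untouched, matching the "fix dynamic" note in the commit message. For context, a sketch of how symbolic sizes show up in a placeholder's FakeTensor shape, assuming a recent torch.export with Dim-based dynamic_shapes; the module and names are illustrative:

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
  def forward(self, x):
    return x + 1

ep = export(M(), (torch.ones(4, 8),), dynamic_shapes={"x": {0: Dim("batch")}})
fake_t = next(iter(ep.graph.nodes)).meta["val"]  # the input placeholder's FakeTensor
# Symbolic sizes (torch.SymInt) are not plain ints, so they are collected as dynamic.
dynamic_dims = [i for i, s in enumerate(fake_t.shape) if not isinstance(s, int)]
print(dynamic_dims)  # [0]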
@@ -285,10 +286,9 @@ def run_node(self, n) -> Any:
     res = super().run_node(n)
     if self.gm_serializer is not None:
       node_metadata = json.dumps(self.gm_serializer.serialize_metadata(n))
-      pytree.tree_map_only(
-          torch.Tensor,
-          lambda x: torch_xla._XLAC._set_xla_custom_op_name_prefix(
-              x, node_metadata, 1), res)
+      pytree.tree_map_only(torch.Tensor,
+                           lambda x: write_mlir_debuginfo(x, node_metadata),
+                           res)
     return res
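For context, pytree.tree_map_only applies a function only to leaves of a given type, which is how the hunk above tags every tensor in a node's possibly nested output with the serialized node metadata while leaving other leaves alone. A standalone sketch of the idiom, with a no-op stand-in for the write_mlir_debuginfo call:

import torch
import torch.utils._pytree as pytree

def tag(t):
  # Stand-in for lambda x: write_mlir_debuginfo(x, node_metadata); returns its input.
  return t

res = {"logits": torch.ones(2, 3), "aux": ("not a tensor", torch.zeros(1))}
pytree.tree_map_only(torch.Tensor, tag, res)  # only the two tensor leaves are visited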


@@ -347,13 +347,10 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   if options.inline_all_constant:
     # Inline all constants.
     torch_xla._XLAC._set_xla_all_numbers_special_scalars(True)
+  xla_hlo_debug_env = os.environ.get("XLA_HLO_DEBUG", "0")
   if options.export_node_metadata:
     gm_serializer = GraphModuleSerializer(exported_model.graph_signature,
                                           exported_model.module_call_graph)
-    if "XLA_HLO_DEBUG" in os.environ:
-      xla_hlo_debug_env = os.environ["XLA_HLO_DEBUG"]
-    else:
-      xla_hlo_debug_env = "0"
     os.environ["XLA_HLO_DEBUG"] = "1"
   else:
     gm_serializer = None
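Reading XLA_HLO_DEBUG once with a default of "0" replaces the old four-line lookup and makes the saved value available whether or not export_node_metadata is set. A minimal sketch of the surrounding save/set/restore pattern, assuming the saved value is written back after lowering (the restore sits outside the shown hunk); the try body is a placeholder:

import os

saved = os.environ.get("XLA_HLO_DEBUG", "0")  # one call instead of the old if/else
os.environ["XLA_HLO_DEBUG"] = "1"             # emit HLO debug metadata while lowering
try:
  pass  # ... build the StableHLO bundle ...
finally:
  os.environ["XLA_HLO_DEBUG"] = saved         # put the caller's setting back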
