diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 85f9287b084..53467bd285f 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -350,6 +350,10 @@ def _exported_program_to_stablehlo_bundle(exported_model, if options.export_node_metadata: gm_serializer = GraphModuleSerializer(exported_model.graph_signature, exported_model.module_call_graph) + if "XLA_HLO_DEBUG" in os.environ: + xla_hlo_debug_env = os.environ["XLA_HLO_DEBUG"] + else: + xla_hlo_debug_env = "0" os.environ["XLA_HLO_DEBUG"] = "1" else: gm_serializer = None @@ -478,6 +482,8 @@ def _exported_program_to_stablehlo_bundle(exported_model, # Recover the global flag to not inline all scalars. torch_xla._XLAC._set_xla_all_numbers_special_scalars(False) + # Recover the global XLA_HLO_DEBUG flag + os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_env return bundle