From f8b6f1a8cbcbbac9adca7383a29cba75dffb513d Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Fri, 10 May 2024 22:07:39 +0000 Subject: [PATCH] add node metadata export --- test/stablehlo/test_exports.py | 25 +++++++++++++++++++++++++ torch_xla/stablehlo.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/test/stablehlo/test_exports.py b/test/stablehlo/test_exports.py index 6208ae1ca52..13d0246a88d 100644 --- a/test/stablehlo/test_exports.py +++ b/test/stablehlo/test_exports.py @@ -133,6 +133,31 @@ def forward(self, x): shlo = exported_program_to_stablehlo(exported, options=export_options) self.assertEqual(shlo._bundle.state_dict, {}) + def test_export_node_metadata(self): + + class M(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(in_features=4, out_features=16, bias=True) + self.fc2 = torch.nn.Linear(in_features=16, out_features=10, bias=True) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return torch.relu(x) + + args = (torch.rand(2, 4),) + ep = torch.export.export(M(), args) + export_options = StableHLOExportOptions() + export_options.export_node_metadata = True + shlo = exported_program_to_stablehlo(ep, options=export_options) + shlo_text = shlo.get_stablehlo_text() + print(shlo_text) + self.assertTrue('stack_trace' in shlo_text) + self.assertTrue('nn_module_stack' in shlo_text) + self.assertTrue('source_fn_stack' in shlo_text) + if __name__ == '__main__': unittest.main() diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 103cb7161be..85f9287b084 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -11,6 +11,7 @@ import torch_xla import torch_xla.experimental.quantized from torch._decomp import get_decompositions +from torch._export.serde.serialize import GraphModuleSerializer from torch.fx import _pytree as fx_pytree from torch.utils import _pytree as pytree from torch_xla.core import dynamo_bridge @@ -64,6 +65,8 @@ class StableHLOExportOptions: # Ops that will be mapped to stablehlo.custom_call in the # exported StableHLO graph. custom_ops_allowed_in_graph: Set[str] = field(default_factory=set) + # Export node metadata to NamedLoc in StableHLO. + export_node_metadata: bool = False class StableHLOGraphModule: @@ -219,11 +222,13 @@ class StableHLOModelBundle: class XLAExportInterpreter(torch.fx.Interpreter): - def __init__(self, module, device, custom_ops_allowed_in_graph): + def __init__(self, module, device, custom_ops_allowed_in_graph, + gm_serializer): self._device = device super().__init__(module) self.tensor_id_to_dynamic_dims = {} self.custom_ops_allowed_in_graph = custom_ops_allowed_in_graph + self.gm_serializer = gm_serializer def _mark_dynamic(self, tensor, dynamic_dims): tid = torch_xla._XLAC._xla_get_tensor_id(tensor) @@ -266,9 +271,8 @@ def run_node(self, n) -> Any: dynamic_dims = [ i for i, x in enumerate(fake_t.shape) if not isinstance(x, int) ] - self._mark_dynamic(res, dynamic_dims) return res - if n.op == 'call_function': + elif n.op == 'call_function': if hasattr(n.target, 'namespace' ) and n.target.namespace in self.custom_ops_allowed_in_graph: output_shapes, output_dtypes = extract_custom_call_outputs_shape_dtype( @@ -276,7 +280,16 @@ def run_node(self, n) -> Any: call_name = str(n.target) n.target = stablehlo_custom_call n.args = (n.args, call_name, output_shapes, output_dtypes) - return super().run_node(n) + res = super().run_node(n) + else: + res = super().run_node(n) + if self.gm_serializer is not None: + node_metadata = json.dumps(self.gm_serializer.serialize_metadata(n)) + pytree.tree_map_only( + torch.Tensor, + lambda x: torch_xla._XLAC._set_xla_custom_op_name_prefix( + x, node_metadata, 1), res) + return res def _extract_input_args(exported_model, options): @@ -334,8 +347,16 @@ def _exported_program_to_stablehlo_bundle(exported_model, if options.inline_all_constant: # Inline all constants. torch_xla._XLAC._set_xla_all_numbers_special_scalars(True) + if options.export_node_metadata: + gm_serializer = GraphModuleSerializer(exported_model.graph_signature, + exported_model.module_call_graph) + os.environ["XLA_HLO_DEBUG"] = "1" + else: + gm_serializer = None + xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device, - options.custom_ops_allowed_in_graph) + options.custom_ops_allowed_in_graph, + gm_serializer) with torch.no_grad(): res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False) res = res[num_mutations:]