add node metadata export
Siyuan Liu committed May 13, 2024
1 parent 5cb473a · commit f8b6f1a
Showing 2 changed files with 51 additions and 5 deletions.
25 changes: 25 additions & 0 deletions test/stablehlo/test_exports.py
@@ -133,6 +133,31 @@ def forward(self, x):
     shlo = exported_program_to_stablehlo(exported, options=export_options)
     self.assertEqual(shlo._bundle.state_dict, {})
 
+  def test_export_node_metadata(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(in_features=4, out_features=16, bias=True)
+        self.fc2 = torch.nn.Linear(in_features=16, out_features=10, bias=True)
+
+      def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return torch.relu(x)
+
+    args = (torch.rand(2, 4),)
+    ep = torch.export.export(M(), args)
+    export_options = StableHLOExportOptions()
+    export_options.export_node_metadata = True
+    shlo = exported_program_to_stablehlo(ep, options=export_options)
+    shlo_text = shlo.get_stablehlo_text()
+    print(shlo_text)
+    self.assertTrue('stack_trace' in shlo_text)
+    self.assertTrue('nn_module_stack' in shlo_text)
+    self.assertTrue('source_fn_stack' in shlo_text)
+
 
 if __name__ == '__main__':
   unittest.main()
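
Note (not part of the commit): a minimal standalone sketch of the same flow outside unittest. It assumes the torch_xla.stablehlo API exercised by the test above; the toy model and variable names are illustrative only.

# Hedged usage sketch; mirrors the test above rather than any official example.
import torch
from torch_xla.stablehlo import (StableHLOExportOptions,
                                 exported_program_to_stablehlo)

model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU())
ep = torch.export.export(model, (torch.rand(2, 4),))

options = StableHLOExportOptions()
options.export_node_metadata = True  # flag introduced by this commit
shlo = exported_program_to_stablehlo(ep, options=options)

# The serialized node metadata should now show up in the StableHLO text,
# e.g. lines mentioning stack_trace / nn_module_stack / source_fn_stack.
for line in shlo.get_stablehlo_text().splitlines():
  if 'nn_module_stack' in line:
    print(line)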
31 changes: 26 additions & 5 deletions torch_xla/stablehlo.py
@@ -11,6 +11,7 @@
 import torch_xla
 import torch_xla.experimental.quantized
 from torch._decomp import get_decompositions
+from torch._export.serde.serialize import GraphModuleSerializer
 from torch.fx import _pytree as fx_pytree
 from torch.utils import _pytree as pytree
 from torch_xla.core import dynamo_bridge
@@ -64,6 +65,8 @@ class StableHLOExportOptions:
   # Ops that will be mapped to stablehlo.custom_call in the
   # exported StableHLO graph.
   custom_ops_allowed_in_graph: Set[str] = field(default_factory=set)
+  # Export node metadata to NamedLoc in StableHLO.
+  export_node_metadata: bool = False
 
 
 class StableHLOGraphModule:
@@ -219,11 +222,13 @@ class StableHLOModelBundle:
 
 class XLAExportInterpreter(torch.fx.Interpreter):
 
-  def __init__(self, module, device, custom_ops_allowed_in_graph):
+  def __init__(self, module, device, custom_ops_allowed_in_graph,
+               gm_serializer):
     self._device = device
     super().__init__(module)
     self.tensor_id_to_dynamic_dims = {}
     self.custom_ops_allowed_in_graph = custom_ops_allowed_in_graph
+    self.gm_serializer = gm_serializer
 
   def _mark_dynamic(self, tensor, dynamic_dims):
     tid = torch_xla._XLAC._xla_get_tensor_id(tensor)
@@ -266,17 +271,25 @@ def run_node(self, n) -> Any:
         dynamic_dims = [
             i for i, x in enumerate(fake_t.shape) if not isinstance(x, int)
         ]
         self._mark_dynamic(res, dynamic_dims)
-      return res
-    if n.op == 'call_function':
+    elif n.op == 'call_function':
       if hasattr(n.target, 'namespace'
                 ) and n.target.namespace in self.custom_ops_allowed_in_graph:
         output_shapes, output_dtypes = extract_custom_call_outputs_shape_dtype(
             n)
         call_name = str(n.target)
         n.target = stablehlo_custom_call
         n.args = (n.args, call_name, output_shapes, output_dtypes)
-    return super().run_node(n)
+      res = super().run_node(n)
+    else:
+      res = super().run_node(n)
+    if self.gm_serializer is not None:
+      node_metadata = json.dumps(self.gm_serializer.serialize_metadata(n))
+      pytree.tree_map_only(
+          torch.Tensor,
+          lambda x: torch_xla._XLAC._set_xla_custom_op_name_prefix(
+              x, node_metadata, 1), res)
+    return res
 
 
 def _extract_input_args(exported_model, options):
@@ -334,8 +347,16 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   if options.inline_all_constant:
     # Inline all constants.
     torch_xla._XLAC._set_xla_all_numbers_special_scalars(True)
+  if options.export_node_metadata:
+    gm_serializer = GraphModuleSerializer(exported_model.graph_signature,
+                                          exported_model.module_call_graph)
+    os.environ["XLA_HLO_DEBUG"] = "1"
+  else:
+    gm_serializer = None
+
   xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device,
-                                         options.custom_ops_allowed_in_graph)
+                                         options.custom_ops_allowed_in_graph,
+                                         gm_serializer)
   with torch.no_grad():
     res = xla_interpreter.run(*_flat_input_args, enable_io_processing=False)
     res = res[num_mutations:]
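
Note (not part of the commit): a small sketch of the per-node metadata string that run_node attaches as a custom-op name prefix. It relies only on the GraphModuleSerializer usage shown in the diff above; the toy module is made up for illustration.

# Hedged illustration of the JSON metadata serialized per FX node.
import json

import torch
from torch._export.serde.serialize import GraphModuleSerializer


class Toy(torch.nn.Module):

  def forward(self, x):
    return torch.relu(x + 1)


ep = torch.export.export(Toy(), (torch.rand(2, 4),))
serializer = GraphModuleSerializer(ep.graph_signature, ep.module_call_graph)

for node in ep.graph_module.graph.nodes:
  # This is the string run_node passes to _set_xla_custom_op_name_prefix;
  # for call_function nodes it typically carries stack_trace,
  # nn_module_stack and source_fn_stack entries, as asserted in the test.
  print(node.name, json.dumps(serializer.serialize_metadata(node)))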
