From bbbe01d92c7a7b3866c8dfe8e03a9eed357daae4 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Fri, 22 Mar 2024 06:24:32 +0000
Subject: [PATCH] add missing shape for replaced fx node in fx passes

---
Notes (below the "---" marker, so not part of the commit message):
  * What the diff does: each FX pass touched below now copies meta['val']
    from the node being replaced onto its replacement node, and a new script
    exports deepset/tinyroberta-squad2 with a dynamic batch dimension and
    prints the torch-vs-TF output difference.
  * NOTE(review): line breaks and indentation were reconstructed from a copy
    of this patch whose newlines had been collapsed into spaces. The leading
    whitespace of the context lines in the unbounded_dynamism_export.py
    hunks is a best guess -- TODO confirm against the target file, or apply
    with `git apply --ignore-whitespace`.
  * NOTE(review): the script prints np.max of the *signed* difference, which
    under-reports negative deviations; consider np.max(np.abs(...)).

 .../export_tinyroberta_unbounded_dynamism.py  | 61 +++++++++++++++++++
 .../experimental/unbounded_dynamism_export.py |  4 ++
 2 files changed, 65 insertions(+)
 create mode 100644 test/stablehlo/export_tinyroberta_unbounded_dynamism.py

diff --git a/test/stablehlo/export_tinyroberta_unbounded_dynamism.py b/test/stablehlo/export_tinyroberta_unbounded_dynamism.py
new file mode 100644
index 000000000000..fd4a44c796e2
--- /dev/null
+++ b/test/stablehlo/export_tinyroberta_unbounded_dynamism.py
@@ -0,0 +1,61 @@
+import os
+
+import numpy as np
+import tensorflow as tf
+import torch
+import torch.nn as nn
+import torch_xla
+from torch.utils import _pytree as pytree
+from torch.export import Dim, export
+from torch_xla.stablehlo import exported_program_to_stablehlo
+from torch_xla.tf_saved_model_integration import \
+    save_torch_module_as_tf_saved_model
+from transformers import AutoModelForQuestionAnswering, AutoTokenizer
+
+os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
+class WrapModel(torch.nn.Module):
+
+  def __init__(self):
+    super().__init__()
+    self._model = AutoModelForQuestionAnswering.from_pretrained(
+        "deepset/tinyroberta-squad2")
+
+  def forward(self, input, mask):
+    res = self._model.forward(input, mask)
+    return tuple(
+        x for x in (res.loss, res.start_logits, res.end_logits,
+                    res.hidden_states) if x is not None)
+
+
+def _get_fake_pipeline_model_inputs():
+  tokens_len = 10
+  input_ids = torch.randint(
+      low=0, high=2000, size=(3, tokens_len), dtype=torch.int64)
+  attention_mask = torch.ones((3, tokens_len), dtype=torch.int64)
+  return (input_ids, attention_mask)
+
+
+model = WrapModel()
+args = _get_fake_pipeline_model_inputs()
+dynamic_shapes = ({0: Dim("bs")}, {0: Dim("bs")})
+ep = export(model, args=args, dynamic_shapes=dynamic_shapes)
+
+tmp_dir = "/tmp/tiny_roberta/tiny_roberta_export"
+save_torch_module_as_tf_saved_model(
+    model, args, tmp_dir, dynamic_shapes=dynamic_shapes)
+
+tokens_len = 10
+args = (torch.randint(
+    low=0, high=2000, size=(2, tokens_len),
+    dtype=torch.int64), torch.ones((2, tokens_len), dtype=torch.int64))
+loaded_m = tf.saved_model.load(tmp_dir)
+tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
+                                args)
+
+tf_output = loaded_m.f(*tf_input)
+with torch.no_grad():
+  torch_output = model(*args)
+  print(np.max(torch_output[0].numpy() - tf_output[0].numpy()))
+  print(np.max(torch_output[1].numpy() - tf_output[1].numpy()))
diff --git a/torch_xla/experimental/unbounded_dynamism_export.py b/torch_xla/experimental/unbounded_dynamism_export.py
index 042ba98d2c09..d517e79276c9 100644
--- a/torch_xla/experimental/unbounded_dynamism_export.py
+++ b/torch_xla/experimental/unbounded_dynamism_export.py
@@ -113,6 +113,7 @@ def decompose_dynamic_shape_select(gm: GraphModule):
               view_new_shape.append(get_dim_size_node)
           view_args = (slice_node, view_new_shape)
           view_node = graph.call_function(aten.view.default, view_args)
+          view_node.meta['val'] = n.meta['val']
           n.replace_all_uses_with(view_node)
           graph.erase_node(n)
 
@@ -151,6 +152,7 @@ def decompose_split_with_sizes(gm: GraphModule):
           for consumer in consumers:
             assert n.op == "call_function" and consumer.target.__name__ == "getitem"
             slice_idx = consumer.args[1]
+            decomposed_slices[slice_idx].meta['val'] = consumer.meta['val']
             consumer.replace_all_uses_with(decomposed_slices[slice_idx])
 
 
@@ -201,6 +203,7 @@ def flatten_embedding_indices_tensor(gm: GraphModule):
         with graph.inserting_after(n):
           recover_view_args = (n, recover_view_shape)
           recover_view_node = graph.call_function(aten.view, recover_view_args)
+          recover_view_node.meta['val'] = n.meta['val']
           n.replace_all_uses_with(recover_view_node)
           recover_view_node.update_arg(0, n)
 
@@ -337,6 +340,7 @@ def dynamic_unsqueeze_to_view(gm: GraphModule):
         view_args = (unsqueeze_src,
                      view_args[:squeezed_dim] + [1] + view_args[squeezed_dim:])
         view_node = graph.call_function(aten.view, view_args)
+        view_node.meta['val'] = n.meta['val']
         n.replace_all_uses_with(view_node)
         graph.erase_node(n)
 