add missing shape for replaced fx node in fx passes
Siyuan Liu committed Mar 22, 2024
1 parent 782f05d commit bbbe01d
Showing 2 changed files with 65 additions and 0 deletions.
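
These fx passes (part of torch_xla's unbounded-dynamism export) rewrite the exported graph by replacing a node with a new one, for example substituting an aten.view for a dynamic select, split, embedding-index flatten, or unsqueeze. The replacement node was previously created without the meta['val'] entry the original node carried (the FakeTensor recording shape and dtype), so later steps that read shapes from meta['val'] had nothing to consult. The commit copies meta['val'] from the replaced node onto its replacement in each pass. A minimal, self-contained sketch of the pattern (a toy pass, not the torch_xla code; the reshape-to-view match is only illustrative):

import torch
from torch.fx import GraphModule

aten = torch.ops.aten


def replace_reshape_with_view(gm: GraphModule):
  """Toy FX pass illustrating the replace-then-copy-meta pattern."""
  graph = gm.graph
  for n in list(graph.nodes):
    if n.op == "call_function" and n.target == aten.reshape.default:
      with graph.inserting_after(n):
        view_node = graph.call_function(aten.view.default, n.args)
      # The line this commit adds in each pass: copy the FakeTensor so the
      # replacement node still carries shape/dtype information.
      view_node.meta['val'] = n.meta['val']
      n.replace_all_uses_with(view_node)
      graph.erase_node(n)
  graph.lint()
  gm.recompile()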
61 changes: 61 additions & 0 deletions test/stablehlo/export_tinyroberta_unbounded_dynamism.py
@@ -0,0 +1,61 @@
import os

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch_xla
from torch.utils import _pytree as pytree
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import \
    save_torch_module_as_tf_saved_model
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


class WrapModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self._model = AutoModelForQuestionAnswering.from_pretrained(
        "deepset/tinyroberta-squad2")

  def forward(self, input, mask):
    res = self._model.forward(input, mask)
    return tuple(
        x for x in (res.loss, res.start_logits, res.end_logits,
                    res.hidden_states) if x is not None)


def _get_fake_pipeline_model_inputs():
  tokens_len = 10
  input_ids = torch.randint(
      low=0, high=2000, size=(3, tokens_len), dtype=torch.int64)
  attention_mask = torch.ones((3, tokens_len), dtype=torch.int64)
  return (input_ids, attention_mask)


model = WrapModel()
args = _get_fake_pipeline_model_inputs()
dynamic_shapes = ({0: Dim("bs")}, {0: Dim("bs")})
ep = export(model, args=args, dynamic_shapes=dynamic_shapes)

tmp_dir = "/tmp/tiny_roberta/tiny_roberta_export"
save_torch_module_as_tf_saved_model(
    model, args, tmp_dir, dynamic_shapes=dynamic_shapes)

tokens_len = 10
args = (torch.randint(
    low=0, high=2000, size=(2, tokens_len),
    dtype=torch.int64), torch.ones((2, tokens_len), dtype=torch.int64))
loaded_m = tf.saved_model.load(tmp_dir)
tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
                                args)

tf_output = loaded_m.f(*tf_input)
with torch.no_grad():
  torch_output = model(*args)
print(np.max(torch_output[0].numpy() - tf_output[0].numpy()))
print(np.max(torch_output[1].numpy() - tf_output[1].numpy()))
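
The test exports the wrapped tinyroberta question-answering model with the batch dimension declared dynamic (Dim("bs") on dim 0 of both inputs), saves it as a TF SavedModel, reloads it, and runs it at a batch size (2) different from the one used at export time (3), printing the largest elementwise difference between the PyTorch and TF start/end logits. A possible hardening, not part of the committed test and with an assumed tolerance, would assert closeness instead of printing:

assert np.allclose(torch_output[0].numpy(), tf_output[0].numpy(), atol=1e-5)
assert np.allclose(torch_output[1].numpy(), tf_output[1].numpy(), atol=1e-5)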
4 changes: 4 additions & 0 deletions torch_xla/experimental/unbounded_dynamism_export.py
@@ -113,6 +113,7 @@ def decompose_dynamic_shape_select(gm: GraphModule):
view_new_shape.append(get_dim_size_node)
view_args = (slice_node, view_new_shape)
view_node = graph.call_function(aten.view.default, view_args)
view_node.meta['val'] = n.meta['val']
n.replace_all_uses_with(view_node)
graph.erase_node(n)

@@ -151,6 +152,7 @@ def decompose_split_with_sizes(gm: GraphModule):
for consumer in consumers:
  assert n.op == "call_function" and consumer.target.__name__ == "getitem"
  slice_idx = consumer.args[1]
  decomposed_slices[slice_idx].meta['val'] = consumer.meta['val']
  consumer.replace_all_uses_with(decomposed_slices[slice_idx])


@@ -201,6 +203,7 @@ def flatten_embedding_indices_tensor(gm: GraphModule):
with graph.inserting_after(n):
  recover_view_args = (n, recover_view_shape)
  recover_view_node = graph.call_function(aten.view, recover_view_args)
  recover_view_node.meta['val'] = n.meta['val']
  n.replace_all_uses_with(recover_view_node)
  recover_view_node.update_arg(0, n)

@@ -337,6 +340,7 @@ def dynamic_unsqueeze_to_view(gm: GraphModule):
view_args = (unsqueeze_src,
             view_args[:squeezed_dim] + [1] + view_args[squeezed_dim:])
view_node = graph.call_function(aten.view, view_args)
view_node.meta['val'] = n.meta['val']
n.replace_all_uses_with(view_node)
graph.erase_node(n)

