Add missing shape metadata for the last replaced fx node (#6805)
lsy323 authored Mar 26, 2024
1 parent 134f5b6 commit 899a0fa
Showing 4 changed files with 109 additions and 28 deletions.
63 changes: 63 additions & 0 deletions test/stablehlo/export_tinyroberta_unbounded_dynamism.py
@@ -0,0 +1,63 @@
import os

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch_xla
from torch.utils import _pytree as pytree
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import \
save_torch_module_as_tf_saved_model
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


class WrapModel(torch.nn.Module):

def __init__(self):
super().__init__()
self._model = AutoModelForQuestionAnswering.from_pretrained(
"deepset/tinyroberta-squad2")

def forward(self, input, mask):
res = self._model.forward(input, mask)
# return tuple(
# x for x in (res.loss, res.start_logits, res.end_logits,
# res.hidden_states) if x is not None)
return res.start_logits, res.end_logits


def _get_fake_pipeline_model_inputs():
tokens_len = 10
input_ids = torch.randint(
low=0, high=2000, size=(3, tokens_len), dtype=torch.int64)
attention_mask = torch.ones((3, tokens_len), dtype=torch.int64)
return (input_ids, attention_mask)


model = WrapModel()
args = _get_fake_pipeline_model_inputs()
dynamic_shapes = ({0: Dim("bs")}, {0: Dim("bs")})
# dynamic_shapes = None
ep = export(model, args=args, dynamic_shapes=dynamic_shapes)

tmp_dir = "/tmp/tiny_roberta/tiny_roberta_export"
save_torch_module_as_tf_saved_model(
model, args, tmp_dir, dynamic_shapes=dynamic_shapes)

tokens_len = 10
args = (torch.randint(
low=0, high=2000, size=(2, tokens_len),
dtype=torch.int64), torch.ones((2, tokens_len), dtype=torch.int64))
loaded_m = tf.saved_model.load(tmp_dir)
tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
args)

tf_output = loaded_m.f(*tf_input)
with torch.no_grad():
torch_output = model(*args)
print(np.max(torch_output[0].numpy() - tf_output[0].numpy()))
print(np.max(torch_output[1].numpy() - tf_output[1].numpy()))
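
The script above saves the module with a symbolic batch dimension and then reloads the TF SavedModel at a different batch size (2 instead of the 3 used at export time), printing the max deviation from eager PyTorch. To look at the dynamism directly, the exported program can also be lowered to StableHLO text; a minimal sketch reusing `ep` from the script above (the `get_stablehlo_text` call is assumed from the usual torch_xla StableHLO helpers, not part of this commit):

# Sketch, not part of the committed test: with EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1
# the batch dimension of the inputs should show up as `?` in the printed StableHLO.
shlo = exported_program_to_stablehlo(ep)
print(shlo.get_stablehlo_text('forward'))
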
4 changes: 2 additions & 2 deletions test/stablehlo/test_export_fx_passes.py
@@ -63,10 +63,10 @@ def test_embedding_indices_flatten(self):
flatten_embedding_indices_tensor(ep.graph_module)
ep.graph_module.recompile()
print(ep)
- self.assertTrue('aten.view' in ep.graph_module.code)
+ self.assertTrue('aten.reshape' in ep.graph_module.code)
replace_dynamic_view_with_xla_op(ep.graph_module)
ep.graph_module.recompile()
- self.assertTrue('aten.view' not in ep.graph_module.code)
+ self.assertTrue('aten.reshape' not in ep.graph_module.code)
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
out2 = ep.module()(*args)
self.assertTrue(torch.allclose(out1, out2))
66 changes: 42 additions & 24 deletions torch_xla/experimental/unbounded_dynamism_export.py
@@ -1,12 +1,18 @@
+ import operator
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
import torch_xla.experimental.xla_dynamic_reshape_ops
+ from torch._inductor.fx_utils import get_fake, get_fake_args_kwargs
from torch.export import export
from torch.fx import Graph, GraphModule, subgraph_rewriter
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

aten = torch.ops.aten


- def wrap_func_as_nn_module(f):
+ def wrap_func_as_nn_module(f: Callable[..., Any]):

class M(torch.nn.Module):

@@ -16,6 +22,18 @@ def forward(self, *args):
return M()


+ def call_function(
+     graph: torch.fx.Graph,
+     target: Callable[..., Any],
+     args: Optional[Tuple[torch.fx.node.Argument, ...]] = None,
+     kwargs: Optional[Dict[str, torch.fx.node.Argument]] = None,
+ ) -> torch.fx.Node:
+   node = graph.call_function(target, args, kwargs)
+   _, args, kwargs = get_fake_args_kwargs(node)
+   node.meta["val"] = target(*args, **kwargs)
+   return node
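
The call_function wrapper added above is the core of this fix: every FX node these passes insert now evaluates its target on the FakeTensor arguments recovered via get_fake_args_kwargs and stores the result in node.meta["val"], so shape queries by later passes and by the lowering keep working. A small self-contained illustration of that metadata using plain torch.export (the Tiny module is made up for the example; no XLA involved):

import torch
from torch.export import export

class Tiny(torch.nn.Module):

  def forward(self, x):
    return x + 1

demo_ep = export(Tiny(), (torch.randn(2, 3),))
for n in demo_ep.graph_module.graph.nodes:
  # Exported nodes carry a FakeTensor in meta["val"]; a node added afterwards with a
  # bare graph.call_function() has no such entry, which is the gap the wrapper closes.
  print(n.op, n.target, n.meta.get("val"))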


def native_layer_norm_impl(input, normalized_shape, weight, bias, eps):
mean = torch.mean(input, -1, keepdim=True)
rstd = torch.rsqrt(torch.var(input, -1, keepdim=True, correction=0) + eps)
@@ -52,7 +70,7 @@ def native_group_norm_pattern(x, weight, bias, N, C, HxW, group, eps):
group, eps)[0]


- def decompose_dynamic_native_layer_norm(gm):
+ def decompose_dynamic_native_layer_norm(gm: GraphModule):
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
gm,
native_layer_norm_pattern,
@@ -65,7 +83,7 @@ def decompose_dynamic_native_layer_norm(gm):
assert len(match_n) == 1
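
native_layer_norm_impl above re-expresses aten.native_layer_norm with primitive ops (the mean and rsqrt-of-variance lines shown), which keeps the symbolic batch dimension representable after decomposition. A quick eager sanity check of that identity, with arbitrary example shapes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)
weight, bias, eps = torch.randn(8), torch.randn(8), 1e-5
mean = torch.mean(x, -1, keepdim=True)
rstd = torch.rsqrt(torch.var(x, -1, keepdim=True, correction=0) + eps)
manual = (x - mean) * rstd * weight + bias
print(torch.allclose(manual, F.layer_norm(x, (8,), weight, bias, eps), atol=1e-5))  # True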


- def decompose_dynamic_native_group_norm(gm):
+ def decompose_dynamic_native_group_norm(gm: GraphModule):
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
gm,
native_group_norm_pattern,
@@ -99,7 +117,7 @@ def decompose_dynamic_shape_select(gm: GraphModule):
select_idx = n.args[2]
slice_args = (select_src_node, select_dim, select_idx,
(select_idx + 1), 1)
- slice_node = graph.call_function(aten.slice, slice_args)
+ slice_node = call_function(graph, aten.slice, slice_args)
view_new_shape = []
for dim, size in enumerate(select_src_shape):
if dim == select_dim:
@@ -108,11 +126,11 @@
view_new_shape.append(size)
else:
get_dim_size_args = (select_src_node, dim)
- get_dim_size_node = graph.call_function(aten.sym_size.int,
-                                          get_dim_size_args)
+ get_dim_size_node = call_function(graph, aten.sym_size.int,
+                                   get_dim_size_args)
view_new_shape.append(get_dim_size_node)
view_args = (slice_node, view_new_shape)
- view_node = graph.call_function(aten.view.default, view_args)
+ view_node = call_function(graph, aten.view.default, view_args)
n.replace_all_uses_with(view_node)
graph.erase_node(n)
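
decompose_dynamic_shape_select rewrites aten.select on a symbolically shaped input as a width-one slice followed by a view that drops the selected dimension, with aten.sym_size.int nodes supplying the symbolic sizes. The underlying identity in plain eager terms (shapes are arbitrary; reshape stands in for the view in this quick check):

import torch

x = torch.randn(4, 5, 6)
dim, idx = 1, 2
selected = torch.ops.aten.select(x, dim, idx)
decomposed = torch.ops.aten.slice(x, dim, idx, idx + 1, 1).reshape(4, 6)
print(torch.equal(selected, decomposed))  # True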

@@ -143,8 +161,8 @@ def decompose_split_with_sizes(gm: GraphModule):
decomposed_slices = []
for size in split_sizes:
slice_args = (src_node, split_dim, start_idx, start_idx + size)
- slice_node = graph.call_function(torch.ops.aten.slice.Tensor,
-                                  slice_args)
+ slice_node = call_function(graph, torch.ops.aten.slice.Tensor,
+                            slice_args)
start_idx += size
decomposed_slices.append(slice_node)
consumers = n.users
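
decompose_split_with_sizes similarly replaces one split_with_sizes node with a chain of aten.slice.Tensor nodes, one per requested chunk, advancing a running start index. Eager equivalent of that decomposition (sizes chosen arbitrarily):

import torch

x = torch.randn(2, 10)
sizes, dim = [3, 3, 4], 1
chunks = torch.split(x, sizes, dim=dim)
start, manual = 0, []
for s in sizes:
  manual.append(torch.ops.aten.slice.Tensor(x, dim, start, start + s))
  start += s
print(all(torch.equal(a, b) for a, b in zip(chunks, manual)))  # True
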
@@ -181,8 +199,8 @@ def flatten_embedding_indices_tensor(gm: GraphModule):
for dim, size in enumerate(indices_shape):
if not isinstance(size, int):
get_dim_size_args = (indices_node, dim)
- get_dim_size_node = graph.call_function(aten.sym_size.int,
-                                          get_dim_size_args)
+ get_dim_size_node = call_function(graph, aten.sym_size.int,
+                                   get_dim_size_args)
recover_view_shape.append(get_dim_size_node)
else:
flatten_mul_scale *= size
@@ -191,21 +209,22 @@ def flatten_embedding_indices_tensor(gm: GraphModule):
recover_view_shape.append(weight_shape[-1])

mul_args = (get_dim_size_node, flatten_mul_scale)
- flatten_size_node = graph.call_function(aten.mul.Scalar, mul_args)
+ flatten_size_node = call_function(graph, operator.mul, mul_args)
view_args = (indices_node, [flatten_size_node])
- view_node = graph.call_function(aten.view, view_args)
+ view_node = call_function(graph, aten.reshape, view_args)
new_embedding_args = n.args[0:1] + (view_node,)
if len(n.args) > 2:
new_embedding_args += n.args[2:]
n.args = new_embedding_args
with graph.inserting_after(n):
recover_view_args = (n, recover_view_shape)
- recover_view_node = graph.call_function(aten.view, recover_view_args)
+ recover_view_node = call_function(graph, aten.reshape,
+                                   recover_view_args)
n.replace_all_uses_with(recover_view_node)
recover_view_node.update_arg(0, n)


- def _is_no_op_slice(n):
+ def _is_no_op_slice(n: torch.fx.Node):
assert n.op == "call_function" and n.target == aten.slice.Tensor
return n.args[2] == 0 and n.args[3] == torch.iinfo(torch.int64).max
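
_is_no_op_slice recognizes the pattern torch.export emits for a slice that keeps an entire dimension: start 0 and end INT64_MAX. A one-line check of why such a node can simply be dropped (shape chosen arbitrarily):

import torch

x = torch.randn(3, 7)
y = torch.ops.aten.slice.Tensor(x, 1, 0, torch.iinfo(torch.int64).max)
print(torch.equal(x, y))  # True: the slice spans the whole dimension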

@@ -267,6 +286,8 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule):
'''
g = gm.graph
ATEN_VIEW_OPS = [
+ torch.ops.aten.reshape,
+ torch.ops.aten.reshape.default,
torch.ops.aten.view,
torch.ops.aten.view.default,
torch.ops.aten._unsafe_view,
@@ -290,10 +311,7 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule):
mul_scaler = 1
sym_size_node = dynamic_src
mul_node = None
- if hasattr(
-     dynamic_src.target,
-     "__name__") and (dynamic_src.target.__name__ == "mul" or
-                      dynamic_src.target.__name__ == "mul.Scalar"):
+ if dynamic_src.target == operator.mul:
assert isinstance(dynamic_src.args[0], int) or isinstance(
dynamic_src.args[1], int)
mul_node = dynamic_src
@@ -326,22 +344,22 @@ def dynamic_unsqueeze_to_view(gm: GraphModule):
]
if len(symbolic_dims) == 0:
continue
- assert len(symbolic_dims) == 1, "Only 1 dimention can be symbolic."
+ assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
view_args = list(src_shape)
with graph.inserting_before(n):
for dim in symbolic_dims:
- get_size_node = graph.call_function(aten.sym_size.int,
-                                     (unsqueeze_src, dim))
+ get_size_node = call_function(graph, aten.sym_size.int,
+                               (unsqueeze_src, dim))
view_args[dim] = get_size_node
squeezed_dim = n.args[1]
view_args = (unsqueeze_src,
view_args[:squeezed_dim] + [1] + view_args[squeezed_dim:])
- view_node = graph.call_function(aten.view, view_args)
+ view_node = call_function(graph, aten.view, view_args)
n.replace_all_uses_with(view_node)
graph.erase_node(n)
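
dynamic_unsqueeze_to_view turns unsqueeze on a symbolically shaped tensor into an explicit view whose shape list splices a 1 into the source shape, with aten.sym_size.int filling in the symbolic entries. The static-shape version of that identity:

import torch

x = torch.randn(5, 4)
squeezed_dim = 1
view_shape = list(x.shape[:squeezed_dim]) + [1] + list(x.shape[squeezed_dim:])
print(torch.equal(torch.unsqueeze(x, squeezed_dim), x.view(view_shape)))  # True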


- def exported_program_has_symbolic_input_shape(ep):
+ def exported_program_has_symbolic_input_shape(ep: torch.export.ExportedProgram):
for n in ep.graph_module.graph.nodes:
if n.op == "placeholder":
fake_t = n.meta['val']
4 changes: 2 additions & 2 deletions torch_xla/experimental/xla_dynamic_reshape_ops.py
@@ -6,7 +6,7 @@
from torch_xla.core.xla_model import XLA_LIB

XLA_LIB.define(
"dynamic_expand(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor"
"dynamic_expand(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor"
)


@@ -61,7 +61,7 @@ def dynamic_expand_meta(


XLA_LIB.define(
"dynamic_view(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor"
"dynamic_view(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor"
)
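
Both schemas above now declare the size argument as SymInt[] rather than int[], so a symbolic dimension produced by aten.sym_size.int can be passed to xla.dynamic_expand / xla.dynamic_view during export without being forced to a concrete Python int. A minimal sketch of what such a declaration looks like in isolation, using a throwaway namespace (the demo_dyn op is hypothetical, not part of torch_xla):

import torch

lib = torch.library.Library("demo_dyn", "DEF")
# SymInt[] lets entries of `size` stay symbolic during tracing; an int[] schema
# would require concrete integers at that point.
lib.define("takes_symint(Tensor input, SymInt[] size) -> Tensor")
print(torch.ops.demo_dyn.takes_symint.default._schema)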


