From 681ff00a467e14ab9398120c8d139b1eb73bceac Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 25 Mar 2024 21:24:12 +0000 Subject: [PATCH] add shape inference after new node is inserted --- .../export_tinyroberta_unbounded_dynamism.py | 63 ++++++++++++++++++ test/stablehlo/test_export_fx_passes.py | 4 +- .../experimental/unbounded_dynamism_export.py | 66 ++++++++++++------- .../experimental/xla_dynamic_reshape_ops.py | 4 +- 4 files changed, 109 insertions(+), 28 deletions(-) create mode 100644 test/stablehlo/export_tinyroberta_unbounded_dynamism.py diff --git a/test/stablehlo/export_tinyroberta_unbounded_dynamism.py b/test/stablehlo/export_tinyroberta_unbounded_dynamism.py new file mode 100644 index 00000000000..47e9cd2f78b --- /dev/null +++ b/test/stablehlo/export_tinyroberta_unbounded_dynamism.py @@ -0,0 +1,63 @@ +import os + +import numpy as np +import tensorflow as tf +import torch +import torch.nn as nn +import torch_xla +from torch.utils import _pytree as pytree +from torch.export import Dim, export +from torch_xla.stablehlo import exported_program_to_stablehlo +from torch_xla.tf_saved_model_integration import \ + save_torch_module_as_tf_saved_model +from transformers import AutoModelForQuestionAnswering, AutoTokenizer + +os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1" + + +class WrapModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self._model = AutoModelForQuestionAnswering.from_pretrained( + "deepset/tinyroberta-squad2") + + def forward(self, input, mask): + res = self._model.forward(input, mask) + # return tuple( + # x for x in (res.loss, res.start_logits, res.end_logits, + # res.hidden_states) if x is not None) + return res.start_logits, res.end_logits + + +def _get_fake_pipeline_model_inputs(): + tokens_len = 10 + input_ids = torch.randint( + low=0, high=2000, size=(3, tokens_len), dtype=torch.int64) + attention_mask = torch.ones((3, tokens_len), dtype=torch.int64) + return (input_ids, attention_mask) + + +model = WrapModel() +args = _get_fake_pipeline_model_inputs() +dynamic_shapes = ({0: Dim("bs")}, {0: Dim("bs")}) +# dynamic_shapes = None +ep = export(model, args=args, dynamic_shapes=dynamic_shapes) + +tmp_dir = "/tmp/tiny_roberta/tiny_roberta_export" +save_torch_module_as_tf_saved_model( + model, args, tmp_dir, dynamic_shapes=dynamic_shapes) + +tokens_len = 10 +args = (torch.randint( + low=0, high=2000, size=(2, tokens_len), + dtype=torch.int64), torch.ones((2, tokens_len), dtype=torch.int64)) +loaded_m = tf.saved_model.load(tmp_dir) +tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()), + args) + +tf_output = loaded_m.f(*tf_input) +with torch.no_grad(): + torch_output = model(*args) + print(np.max(torch_output[0].numpy() - tf_output[0].numpy())) + print(np.max(torch_output[1].numpy() - tf_output[1].numpy())) diff --git a/test/stablehlo/test_export_fx_passes.py b/test/stablehlo/test_export_fx_passes.py index 9827bc56188..d1e731abd6e 100644 --- a/test/stablehlo/test_export_fx_passes.py +++ b/test/stablehlo/test_export_fx_passes.py @@ -63,10 +63,10 @@ def test_embedding_indices_flatten(self): flatten_embedding_indices_tensor(ep.graph_module) ep.graph_module.recompile() print(ep) - self.assertTrue('aten.view' in ep.graph_module.code) + self.assertTrue('aten.reshape' in ep.graph_module.code) replace_dynamic_view_with_xla_op(ep.graph_module) ep.graph_module.recompile() - self.assertTrue('aten.view' not in ep.graph_module.code) + self.assertTrue('aten.reshape' not in ep.graph_module.code) self.assertTrue('xla.dynamic_view' in ep.graph_module.code) out2 = ep.module()(*args) self.assertTrue(torch.allclose(out1, out2)) diff --git a/torch_xla/experimental/unbounded_dynamism_export.py b/torch_xla/experimental/unbounded_dynamism_export.py index 042ba98d2c0..b5f0637fd12 100644 --- a/torch_xla/experimental/unbounded_dynamism_export.py +++ b/torch_xla/experimental/unbounded_dynamism_export.py @@ -1,12 +1,18 @@ +import operator +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union + import torch import torch_xla.experimental.xla_dynamic_reshape_ops +from torch._inductor.fx_utils import get_fake, get_fake_args_kwargs from torch.export import export from torch.fx import Graph, GraphModule, subgraph_rewriter +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map aten = torch.ops.aten -def wrap_func_as_nn_module(f): +def wrap_func_as_nn_module(f: Callable[..., Any]): class M(torch.nn.Module): @@ -16,6 +22,18 @@ def forward(self, *args): return M() +def call_function( + graph: torch.fx.Graph, + target: Callable[..., Any], + args: Optional[Tuple[torch.fx.node.Argument, ...]] = None, + kwargs: Optional[Dict[str, torch.fx.node.Argument]] = None, +) -> torch.fx.Node: + node = graph.call_function(target, args, kwargs) + _, args, kwargs = get_fake_args_kwargs(node) + node.meta["val"] = target(*args, **kwargs) + return node + + def native_layer_norm_impl(input, normalized_shape, weight, bias, eps): mean = torch.mean(input, -1, keepdim=True) rstd = torch.rsqrt(torch.var(input, -1, keepdim=True, correction=0) + eps) @@ -52,7 +70,7 @@ def native_group_norm_pattern(x, weight, bias, N, C, HxW, group, eps): group, eps)[0] -def decompose_dynamic_native_layer_norm(gm): +def decompose_dynamic_native_layer_norm(gm: GraphModule): replaced_patterns = subgraph_rewriter.replace_pattern_with_filters( gm, native_layer_norm_pattern, @@ -65,7 +83,7 @@ def decompose_dynamic_native_layer_norm(gm): assert len(match_n) == 1 -def decompose_dynamic_native_group_norm(gm): +def decompose_dynamic_native_group_norm(gm: GraphModule): replaced_patterns = subgraph_rewriter.replace_pattern_with_filters( gm, native_group_norm_pattern, @@ -99,7 +117,7 @@ def decompose_dynamic_shape_select(gm: GraphModule): select_idx = n.args[2] slice_args = (select_src_node, select_dim, select_idx, (select_idx + 1), 1) - slice_node = graph.call_function(aten.slice, slice_args) + slice_node = call_function(graph, aten.slice, slice_args) view_new_shape = [] for dim, size in enumerate(select_src_shape): if dim == select_dim: @@ -108,11 +126,11 @@ def decompose_dynamic_shape_select(gm: GraphModule): view_new_shape.append(size) else: get_dim_size_args = (select_src_node, dim) - get_dim_size_node = graph.call_function(aten.sym_size.int, - get_dim_size_args) + get_dim_size_node = call_function(graph, aten.sym_size.int, + get_dim_size_args) view_new_shape.append(get_dim_size_node) view_args = (slice_node, view_new_shape) - view_node = graph.call_function(aten.view.default, view_args) + view_node = call_function(graph, aten.view.default, view_args) n.replace_all_uses_with(view_node) graph.erase_node(n) @@ -143,8 +161,8 @@ def decompose_split_with_sizes(gm: GraphModule): decomposed_slices = [] for size in split_sizes: slice_args = (src_node, split_dim, start_idx, start_idx + size) - slice_node = graph.call_function(torch.ops.aten.slice.Tensor, - slice_args) + slice_node = call_function(graph, torch.ops.aten.slice.Tensor, + slice_args) start_idx += size decomposed_slices.append(slice_node) consumers = n.users @@ -181,8 +199,8 @@ def flatten_embedding_indices_tensor(gm: GraphModule): for dim, size in enumerate(indices_shape): if not isinstance(size, int): get_dim_size_args = (indices_node, dim) - get_dim_size_node = graph.call_function(aten.sym_size.int, - get_dim_size_args) + get_dim_size_node = call_function(graph, aten.sym_size.int, + get_dim_size_args) recover_view_shape.append(get_dim_size_node) else: flatten_mul_scale *= size @@ -191,21 +209,22 @@ def flatten_embedding_indices_tensor(gm: GraphModule): recover_view_shape.append(weight_shape[-1]) mul_args = (get_dim_size_node, flatten_mul_scale) - flatten_size_node = graph.call_function(aten.mul.Scalar, mul_args) + flatten_size_node = call_function(graph, operator.mul, mul_args) view_args = (indices_node, [flatten_size_node]) - view_node = graph.call_function(aten.view, view_args) + view_node = call_function(graph, aten.reshape, view_args) new_embedding_args = n.args[0:1] + (view_node,) if len(n.args) > 2: new_embedding_args += n.args[2:] n.args = new_embedding_args with graph.inserting_after(n): recover_view_args = (n, recover_view_shape) - recover_view_node = graph.call_function(aten.view, recover_view_args) + recover_view_node = call_function(graph, aten.reshape, + recover_view_args) n.replace_all_uses_with(recover_view_node) recover_view_node.update_arg(0, n) -def _is_no_op_slice(n): +def _is_no_op_slice(n: torch.fx.Node): assert n.op == "call_function" and n.target == aten.slice.Tensor return n.args[2] == 0 and n.args[3] == torch.iinfo(torch.int64).max @@ -267,6 +286,8 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule): ''' g = gm.graph ATEN_VIEW_OPS = [ + torch.ops.aten.reshape, + torch.ops.aten.reshape.default, torch.ops.aten.view, torch.ops.aten.view.default, torch.ops.aten._unsafe_view, @@ -290,10 +311,7 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule): mul_scaler = 1 sym_size_node = dynamic_src mul_node = None - if hasattr( - dynamic_src.target, - "__name__") and (dynamic_src.target.__name__ == "mul" or - dynamic_src.target.__name__ == "mul.Scalar"): + if dynamic_src.target == operator.mul: assert isinstance(dynamic_src.args[0], int) or isinstance( dynamic_src.args[1], int) mul_node = dynamic_src @@ -326,22 +344,22 @@ def dynamic_unsqueeze_to_view(gm: GraphModule): ] if len(symbolic_dims) == 0: continue - assert len(symbolic_dims) == 1, "Only 1 dimention can be symbolic." + assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic." view_args = list(src_shape) with graph.inserting_before(n): for dim in symbolic_dims: - get_size_node = graph.call_function(aten.sym_size.int, - (unsqueeze_src, dim)) + get_size_node = call_function(graph, aten.sym_size.int, + (unsqueeze_src, dim)) view_args[dim] = get_size_node squeezed_dim = n.args[1] view_args = (unsqueeze_src, view_args[:squeezed_dim] + [1] + view_args[squeezed_dim:]) - view_node = graph.call_function(aten.view, view_args) + view_node = call_function(graph, aten.view, view_args) n.replace_all_uses_with(view_node) graph.erase_node(n) -def exported_program_has_symbolic_input_shape(ep): +def exported_program_has_symbolic_input_shape(ep: torch.export.ExportedProgram): for n in ep.graph_module.graph.nodes: if n.op == "placeholder": fake_t = n.meta['val'] diff --git a/torch_xla/experimental/xla_dynamic_reshape_ops.py b/torch_xla/experimental/xla_dynamic_reshape_ops.py index e71cbd80887..a2881a64a74 100644 --- a/torch_xla/experimental/xla_dynamic_reshape_ops.py +++ b/torch_xla/experimental/xla_dynamic_reshape_ops.py @@ -6,7 +6,7 @@ from torch_xla.core.xla_model import XLA_LIB XLA_LIB.define( - "dynamic_expand(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor" + "dynamic_expand(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor" ) @@ -61,7 +61,7 @@ def dynamic_expand_meta( XLA_LIB.define( - "dynamic_view(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor" + "dynamic_view(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor" )