From 533f3dfb537588c6ec58d2f92358c1c3c4c7562d Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Sat, 23 Mar 2024 00:52:07 +0000
Subject: [PATCH] save work

---
 test/stablehlo/test_unbounded_dynamism.py     | 56 +++++++++++++++----
 .../experimental/unbounded_dynamism_export.py | 54 ++++++++++++++----
 .../experimental/xla_dynamic_reshape_ops.py   |  4 +-
 torch_xla/stablehlo.py                        |  6 +-
 torch_xla/tf_saved_model_integration.py       |  2 +-
 5 files changed, 96 insertions(+), 26 deletions(-)

diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py
index 33a2669c2c61..dc0e6b308ab0 100644
--- a/test/stablehlo/test_unbounded_dynamism.py
+++ b/test/stablehlo/test_unbounded_dynamism.py
@@ -8,7 +8,7 @@
 import torch
 import torch_xla.core.xla_model as xm
 from torch.export import Dim, export
-from torch_xla.stablehlo import exported_program_to_stablehlo
+from torch_xla.stablehlo import exported_program_to_stablehlo, StableHLOExportOptions
 
 try:
   from torch_xla.tf_saved_model_integration import \
@@ -740,20 +740,56 @@ def forward(self, x):
     dynamic_shapes = ({0: Dim("bs")},)
     ep = export(m, args, dynamic_shapes=dynamic_shapes)
     out1 = ep.module()(*args)
-    shlo_module = exported_program_to_stablehlo(ep)
+    options = StableHLOExportOptions(dynamic_shapes=dynamic_shapes)
+    shlo_module = exported_program_to_stablehlo(ep, options)
     shlo_text = shlo_module.get_stablehlo_text()
     self.assertTrue(
         re.search(
             r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x43681xf32>",
             shlo_text) is not None)
-    if has_tf_package():
-      with tempfile.TemporaryDirectory() as tempdir:
-        save_torch_module_as_tf_saved_model(
-            m, args, tempdir, dynamic_shapes=dynamic_shapes)
-        self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
-        tf_out = load_save_model_and_inference(tempdir, args)
-        self.assertTrue(
-            np.allclose(out1.detach().numpy(), tf_out[0].numpy(), atol=1e-05))
+    # if has_tf_package():
+    #   with tempfile.TemporaryDirectory() as tempdir:
+    #     save_torch_module_as_tf_saved_model(
+    #         m, args, tempdir, dynamic_shapes=dynamic_shapes)
+    #     self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
+    #     tf_out = load_save_model_and_inference(tempdir, args)
+    #     self.assertTrue(
+    #         np.allclose(out1.detach().numpy(), tf_out[0].numpy(), atol=1e-05))
+
+  def test_dynamic_view_2(self):
+    import torch_xla.experimental.xla_dynamic_reshape_ops
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, [16, 16])
+
+      def forward(self, x):
+        x = self.conv(x)
+        return torch.ops.xla.dynamic_view(x, [x.shape[0], x.shape[1], -1], x, 0, 0, 1)
+
+    m = M().eval()
+    args = (torch.rand((10, 3, 224, 224)),)
+    dynamic_shapes = ({0: Dim("bs")},)
+    ep = export(m, args, dynamic_shapes=dynamic_shapes)
+    print(ep)
+    # out1 = ep.module()(*args)
+    # options = StableHLOExportOptions()
+    # options.dynamic_shapes = dynamic_shapes
+    # shlo_module = exported_program_to_stablehlo(ep, options)
+    # shlo_text = shlo_module.get_stablehlo_text()
+    # self.assertTrue(
+    #     re.search(
+    #         r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x43681xf32>",
+    #         shlo_text) is not None)
+    # if has_tf_package():
+    #   with tempfile.TemporaryDirectory() as tempdir:
+    #     save_torch_module_as_tf_saved_model(
+    #         m, args, tempdir, dynamic_shapes=dynamic_shapes)
+    #     self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
+    #     tf_out = load_save_model_and_inference(tempdir, args)
+    #     self.assertTrue(
+    #         np.allclose(out1.detach().numpy(), tf_out[0].numpy(), atol=1e-05))
 
@unittest.skip("Cannot generate aten.sym_numel in the exported program.") def test_dynamic_view_sym_numel(self): diff --git a/torch_xla/experimental/unbounded_dynamism_export.py b/torch_xla/experimental/unbounded_dynamism_export.py index d517e79276c9..e0cc8e4c0139 100644 --- a/torch_xla/experimental/unbounded_dynamism_export.py +++ b/torch_xla/experimental/unbounded_dynamism_export.py @@ -1,6 +1,8 @@ +import functools + import torch import torch_xla.experimental.xla_dynamic_reshape_ops -from torch.export import export +from torch.export import Dim, export from torch.fx import Graph, GraphModule, subgraph_rewriter aten = torch.ops.aten @@ -15,11 +17,31 @@ def forward(self, *args): return M() +@functools.lru_cache +def _get_built_in_mul_op(): + + class M(torch.nn.Module): + + def forward(self, x): + return x.shape[0] * 10 + + args = (torch.rand(2, 3),) + dynamic_shapes = ({0: Dim("dim")},) + m = M() + ep = export(m, args=args, dynamic_shapes=dynamic_shapes) + built_in_mul = None + for n in ep.graph_module.graph.nodes: + if hasattr(n.target, "__name__") and n.target.__name__ == "mul": + built_in_mul = n.target + break + assert built_in_mul is not None + return built_in_mul + def native_layer_norm_impl(input, normalized_shape, weight, bias, eps): mean = torch.mean(input, -1, keepdim=True) rstd = torch.rsqrt(torch.var(input, -1, keepdim=True, correction=0) + eps) - out = ((input + mean * -1) * rstd) # sub doesn't support unbounded dynamism + out = (input - mean) * rstd out = out * weight + bias return out @@ -27,7 +49,7 @@ def native_layer_norm_impl(input, normalized_shape, weight, bias, eps): def native_group_norm_impl(input, weight, bias, N, C, HxW, group, eps): mean = torch.mean(input, -1, keepdim=True) rstd = torch.rsqrt(torch.var(input, -1, keepdim=True, correction=0) + eps) - out = ((input + mean * -1) * rstd) # sub doesn't support unbounded dynamism + out = (input - mean) * rstd weight = weight.unsqueeze(1) bias = bias.unsqueeze(1) out = out * weight + bias @@ -57,7 +79,7 @@ def decompose_dynamic_native_layer_norm(gm): gm, native_layer_norm_pattern, native_layer_norm_impl, - ignore_literals=False) + ignore_literals=True) # Only support normalize along the last dim now. Check if replacement is valid. 
   for matches in replaced_patterns:
     for pattern_n, match_n in matches.nodes_map.items():
@@ -113,7 +135,6 @@ def decompose_dynamic_shape_select(gm: GraphModule):
             view_new_shape.append(get_dim_size_node)
         view_args = (slice_node, view_new_shape)
         view_node = graph.call_function(aten.view.default, view_args)
-        view_node.meta['val'] = n.meta['val']
         n.replace_all_uses_with(view_node)
         graph.erase_node(n)
 
@@ -193,17 +214,18 @@ def flatten_embedding_indices_tensor(gm: GraphModule):
        recover_view_shape.append(weight_shape[-1])
 
       mul_args = (get_dim_size_node, flatten_mul_scale)
-      flatten_size_node = graph.call_function(aten.mul.Scalar, mul_args)
-      view_args = (indices_node, [flatten_size_node])
-      view_node = graph.call_function(aten.view, view_args)
+      flatten_size_node = graph.call_function(_get_built_in_mul_op(),
+                                              mul_args)
+      # view_args = (indices_node, [flatten_size_node])
+      view_args = (indices_node, [-1])
+      view_node = graph.call_function(aten._unsafe_view, view_args)
       new_embedding_args = n.args[0:1] + (view_node,)
       if len(n.args) > 2:
         new_embedding_args += n.args[2:]
       n.args = new_embedding_args
       with graph.inserting_after(n):
         recover_view_args = (n, recover_view_shape)
-        recover_view_node = graph.call_function(aten.view, recover_view_args)
-        recover_view_node.meta['val'] = n.meta['val']
+        recover_view_node = graph.call_function(aten._unsafe_view, recover_view_args)
         n.replace_all_uses_with(recover_view_node)
         recover_view_node.update_arg(0, n)
 
@@ -279,6 +301,7 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule):
     if n.target in ATEN_VIEW_OPS:
       view_dims = n.args[1]
       sym_sizes = []
+      # TODO: consider the view([-1]) case.
       for dim, node in enumerate(view_dims):
         if not isinstance(node, int):
           sym_sizes.append((dim, node))
@@ -358,7 +381,7 @@ def exported_program_has_symbolic_input_shape(ep):
   return False
 
 
-def process_exported_program_with_symbolic_input(ep):
+def process_exported_program_with_symbolic_input(ep, dynamic_shapes):
   passes = [
       decompose_dynamic_shape_select,
       decompose_split_with_sizes,
@@ -375,3 +398,12 @@ def process_exported_program_with_symbolic_input(ep):
     ep.graph_module.graph.eliminate_dead_code()
     ep.graph_module.graph.lint()
     ep.graph_module.recompile()
+    example_args, example_kwargs = ep.example_inputs
+    ep = export(
+        ep.module(),
+        example_args,
+        example_kwargs,
+        dynamic_shapes=dynamic_shapes)
+    print("after pass {}".format(p.__name__))
+    print(ep)
+  return ep
diff --git a/torch_xla/experimental/xla_dynamic_reshape_ops.py b/torch_xla/experimental/xla_dynamic_reshape_ops.py
index e71cbd808876..a2881a64a746 100644
--- a/torch_xla/experimental/xla_dynamic_reshape_ops.py
+++ b/torch_xla/experimental/xla_dynamic_reshape_ops.py
@@ -6,7 +6,7 @@
 from torch_xla.core.xla_model import XLA_LIB
 
 XLA_LIB.define(
-    "dynamic_expand(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor"
+    "dynamic_expand(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor"
 )
 
 
@@ -61,7 +61,7 @@ def dynamic_expand_meta(
 
 
 XLA_LIB.define(
-    "dynamic_view(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor"
+    "dynamic_view(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor"
 )
 
 
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index f470e42ed275..c15fd634797d 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -6,7 +6,7 @@
 import shutil
 import os
 import re
-from typing import List, Tuple, Optional, Mapping, Any, Dict
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
 import dataclasses
 
 import numpy as np
@@ -60,6 +60,7 @@ class StableHLOExportOptions:
   override_tracing_arguments: Optional[Tuple[Any]] = None
   override_tracing_kwargs: Optional[Mapping[str, Any]] = None
   save_weights: bool = True
+  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None
 
 
 class StableHLOGraphModule:
@@ -299,7 +300,8 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   exported_model = exported_model.run_decompositions()
   exported_model = exported_model.run_decompositions(_extra_decompositions)
   if exported_program_has_symbolic_input_shape(exported_model):
-    process_exported_program_with_symbolic_input(exported_model)
+    assert options.dynamic_shapes is not None, "The dynamic_shapes spec passed to torch.export must now also be set on StableHLOExportOptions.dynamic_shapes."
+    exported_model = process_exported_program_with_symbolic_input(exported_model, options.dynamic_shapes)
 
   args, kwargs = exported_model.example_inputs
   assert len(kwargs) == 0, "Export to stablehlo doesnt support kwargs yet."
diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py
index bbc7800535c1..564d0cc083a4 100644
--- a/torch_xla/tf_saved_model_integration.py
+++ b/torch_xla/tf_saved_model_integration.py
@@ -159,7 +159,7 @@ def save_torch_module_as_tf_saved_model(
   """
   exported = torch.export.export(
       torch_model, args, dynamic_shapes=dynamic_shapes)
-  options = stablehlo.StableHLOExportOptions(override_tracing_arguments=args)
+  options = stablehlo.StableHLOExportOptions(override_tracing_arguments=args, dynamic_shapes=dynamic_shapes)
   stablehlo_model = stablehlo.exported_program_to_stablehlo(exported, options)
   save_stablehlo_graph_as_tf(stablehlo_model, saved_model_dir, serving_key,
                              function_alias)
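
Usage sketch for the StableHLOExportOptions.dynamic_shapes field added above. This mirrors the updated test_dynamic_view test; the toy Conv2d module and the dim name "bs" are illustrative only. The point it shows is that the same dynamic_shapes spec given to torch.export.export must also be handed to the StableHLO exporter, as the assert added in _exported_program_to_stablehlo_bundle now requires.

import torch
from torch.export import Dim, export
from torch_xla.stablehlo import (StableHLOExportOptions,
                                 exported_program_to_stablehlo)


class M(torch.nn.Module):
  """Toy module whose batch dimension is exported as unbounded/dynamic."""

  def __init__(self):
    super().__init__()
    self.conv = torch.nn.Conv2d(3, 5, [16, 16])

  def forward(self, x):
    return self.conv(x)


m = M().eval()
args = (torch.rand(10, 3, 224, 224),)
# Mark dim 0 of the single positional input as dynamic.
dynamic_shapes = ({0: Dim("bs")},)

ep = export(m, args, dynamic_shapes=dynamic_shapes)
# The same spec is passed to the StableHLO exporter so it can re-export the
# program after the unbounded-dynamism rewrite passes.
options = StableHLOExportOptions(dynamic_shapes=dynamic_shapes)
shlo = exported_program_to_stablehlo(ep, options)
print(shlo.get_stablehlo_text())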