
Commit

save work
Siyuan Liu committed Mar 23, 2024
1 parent bbbe01d commit 533f3df
Showing 5 changed files with 96 additions and 26 deletions.
56 changes: 46 additions & 10 deletions test/stablehlo/test_unbounded_dynamism.py
@@ -8,7 +8,7 @@
import torch
import torch_xla.core.xla_model as xm
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.stablehlo import exported_program_to_stablehlo, StableHLOExportOptions

try:
from torch_xla.tf_saved_model_integration import \
@@ -740,20 +740,56 @@ def forward(self, x):
dynamic_shapes = ({0: Dim("bs")},)
ep = export(m, args, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(*args)
shlo_module = exported_program_to_stablehlo(ep)
options = StableHLOExportOptions(dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep, options)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x43681xf32>",
shlo_text) is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
tf_out = load_save_model_and_inference(tempdir, args)
self.assertTrue(
np.allclose(out1.detach().numpy(), tf_out[0].numpy(), atol=1e-05))
# if has_tf_package():
# with tempfile.TemporaryDirectory() as tempdir:
# save_torch_module_as_tf_saved_model(
# m, args, tempdir, dynamic_shapes=dynamic_shapes)
# self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
# tf_out = load_save_model_and_inference(tempdir, args)
# self.assertTrue(
# np.allclose(out1.detach().numpy(), tf_out[0].numpy(), atol=1e-05))

def test_dynamic_view_2(self):
import torch_xla.experimental.xla_dynamic_reshape_ops

class M(torch.nn.Module):

def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 5, [16, 16])

def forward(self, x):
x = self.conv(x)
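# The call below follows the dynamic_view schema from xla_dynamic_reshape_ops:
# input=x, size=[x.shape[0], x.shape[1], -1], src_tensor=x, src_dim=0, target_dim=0,
# mul_scaler=1, i.e. the dynamic batch dimension of x is carried into the output.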
return torch.ops.xla.dynamic_view(x, [x.shape[0], x.shape[1], -1], x, 0, 0, 1)

m = M().eval()
args = (torch.rand((10, 3, 224, 224)),)
dynamic_shapes = ({0: Dim("bs")},)
ep = export(m, args, dynamic_shapes=dynamic_shapes)
print(ep)
# out1 = ep.module()(*args)
# options = StableHLOExportOptions()
# options.dynamic_shapes = dynamic_shapes
# shlo_module = exported_program_to_stablehlo(ep, options)
# shlo_text = shlo_module.get_stablehlo_text()
# self.assertTrue(
# re.search(
# r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x43681xf32>",
# shlo_text) is not None)
# if has_tf_package():
# with tempfile.TemporaryDirectory() as tempdir:
# save_torch_module_as_tf_saved_model(
# m, args, tempdir, dynamic_shapes=dynamic_shapes)
# self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
# tf_out = load_save_model_and_inference(tempdir, args)
# self.assertTrue(
# np.allclose(out1.detach().numpy(), tf_out[0].numpy(), atol=1e-05))

@unittest.skip("Cannot generate aten.sym_numel in the exported program.")
def test_dynamic_view_sym_numel(self):
54 changes: 43 additions & 11 deletions torch_xla/experimental/unbounded_dynamism_export.py
@@ -1,6 +1,8 @@
import functools

import torch
import torch_xla.experimental.xla_dynamic_reshape_ops
from torch.export import export
from torch.export import Dim, export
from torch.fx import Graph, GraphModule, subgraph_rewriter

aten = torch.ops.aten
@@ -15,19 +17,39 @@ def forward(self, *args):

return M()

@functools.lru_cache
def _get_built_in_mul_op():

class M(torch.nn.Module):

def forward(self, x):
return x.shape[0] * 10

args = (torch.rand(2, 3),)
dynamic_shapes = ({0: Dim("dim")},)
m = M()
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
built_in_mul = None
for n in ep.graph_module.graph.nodes:
if hasattr(n.target, "__name__") and n.target.__name__ == "mul":
built_in_mul = n.target
break
assert built_in_mul is not None
return built_in_mul
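# Usage sketch: torch.export records the built-in mul when a symbolic dimension is
# multiplied by a constant, so flatten_embedding_indices_tensor below can emit
# graph.call_function(_get_built_in_mul_op(), (get_dim_size_node, flatten_mul_scale))
# to compute a dynamic size inside the FX graph.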


def native_layer_norm_impl(input, normalized_shape, weight, bias, eps):
mean = torch.mean(input, -1, keepdim=True)
rstd = torch.rsqrt(torch.var(input, -1, keepdim=True, correction=0) + eps)
out = ((input + mean * -1) * rstd) # sub doesn't support unbounded dynamism
out = (input - mean) * rstd
out = out * weight + bias
return out


def native_group_norm_impl(input, weight, bias, N, C, HxW, group, eps):
mean = torch.mean(input, -1, keepdim=True)
rstd = torch.rsqrt(torch.var(input, -1, keepdim=True, correction=0) + eps)
out = ((input + mean * -1) * rstd) # sub doesn't support unbounded dynamism
out = (input - mean) * rstd
weight = weight.unsqueeze(1)
bias = bias.unsqueeze(1)
out = out * weight + bias
@@ -57,7 +79,7 @@ def decompose_dynamic_native_layer_norm(gm):
gm,
native_layer_norm_pattern,
native_layer_norm_impl,
ignore_literals=False)
ignore_literals=True)
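# With ignore_literals=True, torch.fx's subgraph rewriter treats literal arguments in
# the pattern (e.g. the eps constant) as wildcards, so layer norms with a different
# eps still match the pattern above.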
# Only support normalize along the last dim now. Check if replacement is valid.
for matches in replaced_patterns:
for pattern_n, match_n in matches.nodes_map.items():
@@ -113,7 +135,6 @@ def decompose_dynamic_shape_select(gm: GraphModule):
view_new_shape.append(get_dim_size_node)
view_args = (slice_node, view_new_shape)
view_node = graph.call_function(aten.view.default, view_args)
view_node.meta['val'] = n.meta['val']
n.replace_all_uses_with(view_node)
graph.erase_node(n)

@@ -193,17 +214,18 @@ def flatten_embedding_indices_tensor(gm: GraphModule):
recover_view_shape.append(weight_shape[-1])

mul_args = (get_dim_size_node, flatten_mul_scale)
flatten_size_node = graph.call_function(aten.mul.Scalar, mul_args)
view_args = (indices_node, [flatten_size_node])
view_node = graph.call_function(aten.view, view_args)
flatten_size_node = graph.call_function(_get_built_in_mul_op(),
mul_args)
# view_args = (indices_node, [flatten_size_node])
view_args = (indices_node, [-1])
view_node = graph.call_function(aten._unsafe_view, view_args)
new_embedding_args = n.args[0:1] + (view_node,)
if len(n.args) > 2:
new_embedding_args += n.args[2:]
n.args = new_embedding_args
with graph.inserting_after(n):
recover_view_args = (n, recover_view_shape)
recover_view_node = graph.call_function(aten.view, recover_view_args)
recover_view_node.meta['val'] = n.meta['val']
recover_view_node = graph.call_function(aten._unsafe_view, recover_view_args)
n.replace_all_uses_with(recover_view_node)
recover_view_node.update_arg(0, n)

@@ -279,6 +301,7 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule):
if n.target in ATEN_VIEW_OPS:
view_dims = n.args[1]
sym_sizes = []
# TODO: consider view([-1])
for dim, node in enumerate(view_dims):
if not isinstance(node, int):
sym_sizes.append((dim, node))
@@ -358,7 +381,7 @@ def exported_program_has_symbolic_input_shape(ep):
return False


def process_exported_program_with_symbolic_input(ep):
def process_exported_program_with_symbolic_input(ep, dynamic_shapes):
passes = [
decompose_dynamic_shape_select,
decompose_split_with_sizes,
@@ -375,3 +398,12 @@ def process_exported_program_with_symbolic_input(ep):
ep.graph_module.graph.eliminate_dead_code()
ep.graph_module.graph.lint()
ep.graph_module.recompile()
example_args, example_kwargs = ep.example_inputs
ep = export(
ep.module(),
example_args,
example_kwargs,
dynamic_shapes=dynamic_shapes)
print("after pass{}".format(p.__name__))
print(ep)
return ep
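The re-export above is why the export-time dynamic_shapes spec now has to be threaded through to the StableHLO conversion. A minimal end-to-end sketch, assuming a working torch_xla install (the ConvFlatten module and shapes are illustrative, mirroring the tests above):

import torch
from torch.export import Dim, export
from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo


class ConvFlatten(torch.nn.Module):
  # Illustrative module: Conv2d followed by a view that keeps the dynamic batch dim.

  def __init__(self):
    super().__init__()
    self.conv = torch.nn.Conv2d(3, 5, [16, 16])

  def forward(self, x):
    return self.conv(x).view(x.shape[0], 5, -1)


m = ConvFlatten().eval()
args = (torch.rand(10, 3, 224, 224),)
dynamic_shapes = ({0: Dim("bs")},)

ep = export(m, args, dynamic_shapes=dynamic_shapes)
# The same spec must now also be passed to the StableHLO exporter so it can
# re-export the program after its unbounded-dynamism rewrite passes.
options = StableHLOExportOptions(dynamic_shapes=dynamic_shapes)
shlo = exported_program_to_stablehlo(ep, options)
print(shlo.get_stablehlo_text())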
4 changes: 2 additions & 2 deletions torch_xla/experimental/xla_dynamic_reshape_ops.py
@@ -6,7 +6,7 @@
from torch_xla.core.xla_model import XLA_LIB

XLA_LIB.define(
"dynamic_expand(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor"
"dynamic_expand(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim) -> Tensor"
)
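# The size entries may now be SymInts produced by torch.export's symbolic shapes
# (e.g. x.shape[0] under a dynamic batch dim), not just plain Python ints, hence
# SymInt[] instead of int[] in the schema.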


@@ -61,7 +61,7 @@ def dynamic_expand_meta(


XLA_LIB.define(
"dynamic_view(Tensor input, int[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor"
"dynamic_view(Tensor input, SymInt[] size, Tensor src_tensor, int src_dim, int target_dim, float mul_scaler) -> Tensor"
)
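# Usage sketch taken from test_dynamic_view_2 above: reshape a conv output while
# keeping the dynamic batch size from dim 0 of the source tensor:
#   torch.ops.xla.dynamic_view(x, [x.shape[0], x.shape[1], -1], x, 0, 0, 1)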


6 changes: 4 additions & 2 deletions torch_xla/stablehlo.py
@@ -6,7 +6,7 @@
import shutil
import os
import re
from typing import List, Tuple, Optional, Mapping, Any, Dict
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import dataclasses

import numpy as np
@@ -60,6 +60,7 @@ class StableHLOExportOptions:
override_tracing_arguments: Optional[Tuple[Any]] = None
override_tracing_kwargs: Optional[Mapping[str, Any]] = None
save_weights: bool = True
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None
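# dynamic_shapes mirrors the dynamic_shapes argument given to torch.export.export and
# is now required whenever the exported program has symbolic input shapes, e.g.
#   StableHLOExportOptions(dynamic_shapes=({0: Dim("bs")},))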


class StableHLOGraphModule:
@@ -299,7 +300,8 @@ def _exported_program_to_stablehlo_bundle(exported_model,
exported_model = exported_model.run_decompositions()
exported_model = exported_model.run_decompositions(_extra_decompositions)
if exported_program_has_symbolic_input_shape(exported_model):
process_exported_program_with_symbolic_input(exported_model)
assert options.dynamic_shapes is not None, "The dynamic_shapes spec used for torch.export must also be passed via StableHLOExportOptions.dynamic_shapes."
exported_model = process_exported_program_with_symbolic_input(exported_model, options.dynamic_shapes)
args, kwargs = exported_model.example_inputs

assert len(kwargs) == 0, "Export to stablehlo doesn't support kwargs yet."
2 changes: 1 addition & 1 deletion torch_xla/tf_saved_model_integration.py
@@ -159,7 +159,7 @@ def save_torch_module_as_tf_saved_model(
"""
exported = torch.export.export(
torch_model, args, dynamic_shapes=dynamic_shapes)
options = stablehlo.StableHLOExportOptions(override_tracing_arguments=args)
options = stablehlo.StableHLOExportOptions(override_tracing_arguments=args, dynamic_shapes=dynamic_shapes)
stablehlo_model = stablehlo.exported_program_to_stablehlo(exported, options)
save_stablehlo_graph_as_tf(stablehlo_model, saved_model_dir, serving_key,
function_alias)
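# dynamic_shapes is forwarded so the StableHLO conversion can re-export after its
# unbounded-dynamism passes; a typical call (sketch, mirroring the tests) is
#   save_torch_module_as_tf_saved_model(m, args, saved_model_dir,
#                                        dynamic_shapes=({0: Dim("bs")},))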
