diff --git a/test/stablehlo/test_saved_model.py b/test/stablehlo/test_saved_model.py index 07e6b5b247ae..bbcca36d0856 100644 --- a/test/stablehlo/test_saved_model.py +++ b/test/stablehlo/test_saved_model.py @@ -1,17 +1,21 @@ +import os +import tempfile +import unittest + import numpy as np +import tensorflow as tf +import torch import torch_xla import torch_xla.core.xla_model as xm -from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo -from torch_xla.tf_saved_model_integration import ( - make_tf_function, save_torch_module_as_tf_saved_model, - save_stablehlo_graph_as_tf) +from torch.export import Dim, export from torch.utils import _pytree as pytree -from torch.export import export, dynamic_dim -import torch - -import tempfile -import unittest -import tensorflow as tf +from torch_xla.stablehlo import (StableHLOExportOptions, + exported_program_to_stablehlo) +from torch_xla.tf_saved_model_integration import ( + make_tf_function, save_stablehlo_graph_as_tf, + save_torch_module_as_tf_saved_model) +from utils import (compare_exported_program_and_saved_model_result, + has_tf_package, wrap_func_as_nn_module) class StableHLOInferenceTest(unittest.TestCase): @@ -26,17 +30,14 @@ def forward(self, a, b): model = MyModule() a = torch.randn(3, 10) b = torch.randn(3, 10) - constraints = [ - dynamic_dim(a, 0), - dynamic_dim(b, 0), - dynamic_dim(a, 0) == dynamic_dim(b, 0) - ] + bs = Dim("bs") + dynamic_shapes = ({0: bs}, {0: bs}) exported = torch.export.export( model, ( a, b, - ), constraints=constraints) + ), dynamic_shapes=dynamic_shapes) shlo = exported_program_to_stablehlo(exported) with tempfile.TemporaryDirectory() as tempdir: save_stablehlo_graph_as_tf( @@ -56,18 +57,51 @@ class M(torch.nn.Module): def forward(self, a, b): return torch.sin(b) - model = M() - data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100)) - output = model(*data) + m = M() + args = (torch.randn(4, 3, 224, 224), torch.randn(1, 100)) + ep = torch.export.export(m, args) + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model(m, args, tempdir) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) + + def test_multiple_outputs(self): + + class M(torch.nn.Module): + + def forward(self, a, b): + return a + b, a * b, a, b + m = M() + args = (torch.rand((2, 3)), torch.rand((2, 3))) + ep = torch.export.export(m, args) with tempfile.TemporaryDirectory() as tempdir: - save_torch_module_as_tf_saved_model(model, data, tempdir) - loaded_m = tf.saved_model.load(tempdir) - res = loaded_m.f(data[0].detach().numpy(), data[1].detach().numpy())[0] - output2 = torch.tensor(res.numpy()) - self.assertTrue(torch.allclose(output, output2, atol=1e-5)) + save_torch_module_as_tf_saved_model(m, args, tempdir) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) + + def test_non_tensor_input_int(self): + m = wrap_func_as_nn_module(torch.ops.aten._softmax.default) + args = (torch.rand((2, 3, 4, 5)), -1, False) + ep = torch.export.export(m, args) + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model(m, args, tempdir) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) + + def test_non_tensor_input_float(self): + m = wrap_func_as_nn_module(torch.ops.aten._cdist_forward) + args = (torch.rand((2, 3, 4)), torch.rand((2, 3, 4)), 2.4, 1) + ep = torch.export.export(m, args) + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model(m, args, tempdir) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) if __name__ == '__main__': + if not has_tf_package(): + print("skip tf.saved_model tests, tf is not installed.") + sys.exit(0) test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/stablehlo/utils.py b/test/stablehlo/utils.py index c5779ea215d5..1e96282c4c25 100644 --- a/test/stablehlo/utils.py +++ b/test/stablehlo/utils.py @@ -1,4 +1,10 @@ import functools +import tempfile +from typing import Any, Dict, Tuple + +import numpy as np +import torch +from torch.utils import _pytree as pytree @functools.lru_cache @@ -8,3 +14,39 @@ def has_tf_package() -> bool: return tensorflow is not None except ImportError: return False + + +def wrap_func_as_nn_module(f): + + class M(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, *args): + return f(*args) + + return M().eval() + + +def load_save_model_and_inference(path: str, args: Tuple[Any, ...]) -> Dict: + assert has_tf_package() + import tensorflow as tf + loaded_m = tf.saved_model.load(path) + tf_input = pytree.tree_map_only(torch.Tensor, + lambda x: tf.constant(x.numpy()), args) + tf_output = loaded_m.f(*tf_input) + return tf_output + + +def compare_exported_program_and_saved_model_result(ep, saved_model_path, args): + tf_output = load_save_model_and_inference(saved_model_path, args) + torch_output = ep(*args) + if not isinstance(torch_output, tuple): + torch_output = (torch_output,) + assert len(torch_output) == len(tf_output) + for idx in range(len(torch_output)): + torch_output_np = torch_output[idx].numpy() + tf_output_np = tf_output[idx].numpy() + assert torch_output_np.dtype == tf_output_np.dtype, f"torch dtype: {torch_output[idx].dtype}, tf dtype: {tf_output[idx].dtype}" + assert np.allclose(torch_output_np, tf_output_np) diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index b10961694eaf..04a17d56a0f5 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -399,7 +399,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, else: signature = VariableSignature( shape=[], - dtype=str(type(arg)), + dtype=type(arg).__name__, ) unused_inputs.append((pos, signature)) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index d183851a6740..dc42f8cb4775 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -60,11 +60,16 @@ def _make_input_signatures( zip(meta.input_locations, meta.input_signature), meta.unused_inputs) if loc.type_ == stablehlo.VariableType.INPUT_ARG } + primitive_type_to_tf_type = {'int': 'int32', 'float': 'float32'} for i in range(len(input_pos_to_spec)): spec = input_pos_to_spec[i] shape = _get_shape_with_dynamic(spec) yield tf.TensorSpec( - shape=shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}') + shape=shape, + dtype=getattr( + tf, primitive_type_to_tf_type[spec.dtype] + if spec.dtype in primitive_type_to_tf_type else spec.dtype), + name=f'args_{i}') def _mangle_tf_root_scope_name(name):