Skip to content

Commit

Permalink
handle non-tensor input in tf saved_model input placeholder (pytorch#…
Browse files Browse the repository at this point in the history
…6640)

Co-authored-by: Siyuan Liu <[email protected]>
  • Loading branch information
2 people authored and amithrm committed Mar 1, 2024
1 parent da4993c commit 62ad221
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 26 deletions.
82 changes: 58 additions & 24 deletions test/stablehlo/test_saved_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import os
import tempfile
import unittest

import numpy as np
import tensorflow as tf
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import (
make_tf_function, save_torch_module_as_tf_saved_model,
save_stablehlo_graph_as_tf)
from torch.export import Dim, export
from torch.utils import _pytree as pytree
from torch.export import export, dynamic_dim
import torch

import tempfile
import unittest
import tensorflow as tf
from torch_xla.stablehlo import (StableHLOExportOptions,
exported_program_to_stablehlo)
from torch_xla.tf_saved_model_integration import (
make_tf_function, save_stablehlo_graph_as_tf,
save_torch_module_as_tf_saved_model)
from utils import (compare_exported_program_and_saved_model_result,
has_tf_package, wrap_func_as_nn_module)


class StableHLOInferenceTest(unittest.TestCase):
Expand All @@ -26,17 +30,14 @@ def forward(self, a, b):
model = MyModule()
a = torch.randn(3, 10)
b = torch.randn(3, 10)
constraints = [
dynamic_dim(a, 0),
dynamic_dim(b, 0),
dynamic_dim(a, 0) == dynamic_dim(b, 0)
]
bs = Dim("bs")
dynamic_shapes = ({0: bs}, {0: bs})

exported = torch.export.export(
model, (
a,
b,
), constraints=constraints)
), dynamic_shapes=dynamic_shapes)
shlo = exported_program_to_stablehlo(exported)
with tempfile.TemporaryDirectory() as tempdir:
save_stablehlo_graph_as_tf(
Expand All @@ -56,18 +57,51 @@ class M(torch.nn.Module):
def forward(self, a, b):
return torch.sin(b)

model = M()
data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100))
output = model(*data)
m = M()
args = (torch.randn(4, 3, 224, 224), torch.randn(1, 100))
ep = torch.export.export(m, args)
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(m, args, tempdir)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_multiple_outputs(self):

class M(torch.nn.Module):

def forward(self, a, b):
return a + b, a * b, a, b

m = M()
args = (torch.rand((2, 3)), torch.rand((2, 3)))
ep = torch.export.export(m, args)
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(model, data, tempdir)
loaded_m = tf.saved_model.load(tempdir)
res = loaded_m.f(data[0].detach().numpy(), data[1].detach().numpy())[0]
output2 = torch.tensor(res.numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))
save_torch_module_as_tf_saved_model(m, args, tempdir)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_non_tensor_input_int(self):
m = wrap_func_as_nn_module(torch.ops.aten._softmax.default)
args = (torch.rand((2, 3, 4, 5)), -1, False)
ep = torch.export.export(m, args)
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(m, args, tempdir)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_non_tensor_input_float(self):
m = wrap_func_as_nn_module(torch.ops.aten._cdist_forward)
args = (torch.rand((2, 3, 4)), torch.rand((2, 3, 4)), 2.4, 1)
ep = torch.export.export(m, args)
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(m, args, tempdir)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)


if __name__ == '__main__':
if not has_tf_package():
print("skip tf.saved_model tests, tf is not installed.")
sys.exit(0)
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
42 changes: 42 additions & 0 deletions test/stablehlo/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import functools
import tempfile
from typing import Any, Dict, Tuple

import numpy as np
import torch
from torch.utils import _pytree as pytree


@functools.lru_cache
Expand All @@ -8,3 +14,39 @@ def has_tf_package() -> bool:
return tensorflow is not None
except ImportError:
return False


def wrap_func_as_nn_module(f):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, *args):
return f(*args)

return M().eval()


def load_save_model_and_inference(path: str, args: Tuple[Any, ...]) -> Dict:
assert has_tf_package()
import tensorflow as tf
loaded_m = tf.saved_model.load(path)
tf_input = pytree.tree_map_only(torch.Tensor,
lambda x: tf.constant(x.numpy()), args)
tf_output = loaded_m.f(*tf_input)
return tf_output


def compare_exported_program_and_saved_model_result(ep, saved_model_path, args):
tf_output = load_save_model_and_inference(saved_model_path, args)
torch_output = ep(*args)
if not isinstance(torch_output, tuple):
torch_output = (torch_output,)
assert len(torch_output) == len(tf_output)
for idx in range(len(torch_output)):
torch_output_np = torch_output[idx].numpy()
tf_output_np = tf_output[idx].numpy()
assert torch_output_np.dtype == tf_output_np.dtype, f"torch dtype: {torch_output[idx].dtype}, tf dtype: {tf_output[idx].dtype}"
assert np.allclose(torch_output_np, tf_output_np)
2 changes: 1 addition & 1 deletion torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
else:
signature = VariableSignature(
shape=[],
dtype=str(type(arg)),
dtype=type(arg).__name__,
)

unused_inputs.append((pos, signature))
Expand Down
7 changes: 6 additions & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,16 @@ def _make_input_signatures(
zip(meta.input_locations, meta.input_signature), meta.unused_inputs)
if loc.type_ == stablehlo.VariableType.INPUT_ARG
}
primitive_type_to_tf_type = {'int': 'int32', 'float': 'float32'}
for i in range(len(input_pos_to_spec)):
spec = input_pos_to_spec[i]
shape = _get_shape_with_dynamic(spec)
yield tf.TensorSpec(
shape=shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}')
shape=shape,
dtype=getattr(
tf, primitive_type_to_tf_type[spec.dtype]
if spec.dtype in primitive_type_to_tf_type else spec.dtype),
name=f'args_{i}')


def _mangle_tf_root_scope_name(name):
Expand Down

0 comments on commit 62ad221

Please sign in to comment.