From 50ac7dd063e6bcbfa90575b9c007fa821843c79b Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Apr 2024 22:22:38 +0000 Subject: [PATCH] Add `torch.tensor` constructor --- .../torch_xla2/test/test_functions.py | 26 +++++++++++++++++ .../torch_xla2/torch_xla2/functions.py | 28 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 experimental/torch_xla2/test/test_functions.py create mode 100644 experimental/torch_xla2/torch_xla2/functions.py diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py new file mode 100644 index 00000000000..89307fb8153 --- /dev/null +++ b/experimental/torch_xla2/test/test_functions.py @@ -0,0 +1,26 @@ +from absl.testing import absltest +from absl.testing import parameterized +import torch +import torch_xla2 +import torch_xla2.functions +import torch_xla2.tensor + +class TestTorchFunctions(parameterized.TestCase): + @parameterized.parameters( + [([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]],)], + [([0, 1],)], + [(3.14159,)], + [([],)], + ) + def test_tensor(self, args): + expected = torch.tensor(*args) + + with torch_xla2.functions.XLAFunctionMode(): + actual = torch.tensor(*args) + + # TODO: dtype is actually important + torch.testing.assert_close(torch_xla2.tensor.j2t(actual), expected, check_dtype=False) + + +if __name__ == '__main__': + absltest.main() diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py new file mode 100644 index 00000000000..fc79ddaadcc --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/functions.py @@ -0,0 +1,28 @@ +"""Tensor constructor overrides""" +import warnings + +import torch +import jax.numpy as jnp + +fns = { + torch.tensor: jnp.array, + # torch.ones: jnp.ones, + # torch.zeros: jnp.zeros, + # torch.arange: jnp.arange, + # torch.linspace: jnp.linspace, + # torch.logspace: jnp.logspace, + # torch.empty: jnp.empty, + # torch.eye: jnp.eye, + # torch.full: jnp.full, +} + +class XLAFunctionMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + jax_func = fns.get(func) + if not jax_func: + raise NotImplementedError(f'No jax function found for {func.__name__}') + + if kwargs: + warnings.warn(f'kwargs not implemented for {kwargs}') + + return jax_func(*args)