Skip to content

Commit

Permalink
Add torch.tensor constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Apr 2, 2024
1 parent 8cf394f commit 50ac7dd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
26 changes: 26 additions & 0 deletions experimental/torch_xla2/test/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from absl.testing import absltest
from absl.testing import parameterized
import torch
import torch_xla2
import torch_xla2.functions
import torch_xla2.tensor

class TestTorchFunctions(parameterized.TestCase):
@parameterized.parameters(
[([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]],)],
[([0, 1],)],
[(3.14159,)],
[([],)],
)
def test_tensor(self, args):
expected = torch.tensor(*args)

with torch_xla2.functions.XLAFunctionMode():
actual = torch.tensor(*args)

# TODO: dtype is actually important
torch.testing.assert_close(torch_xla2.tensor.j2t(actual), expected, check_dtype=False)


if __name__ == '__main__':
absltest.main()
28 changes: 28 additions & 0 deletions experimental/torch_xla2/torch_xla2/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Tensor constructor overrides"""
import warnings

import torch
import jax.numpy as jnp

fns = {
torch.tensor: jnp.array,
# torch.ones: jnp.ones,
# torch.zeros: jnp.zeros,
# torch.arange: jnp.arange,
# torch.linspace: jnp.linspace,
# torch.logspace: jnp.logspace,
# torch.empty: jnp.empty,
# torch.eye: jnp.eye,
# torch.full: jnp.full,
}

class XLAFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
jax_func = fns.get(func)
if not jax_func:
raise NotImplementedError(f'No jax function found for {func.__name__}')

if kwargs:
warnings.warn(f'kwargs not implemented for {kwargs}')

return jax_func(*args)

0 comments on commit 50ac7dd

Please sign in to comment.