From 99edc3285fcef979e9acd6f3937cedb2c7b131d6 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 3 Apr 2024 23:53:54 +0000 Subject: [PATCH] ones --- experimental/torch_xla2/test/test_functions.py | 8 +++++--- experimental/torch_xla2/torch_xla2/functions.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 4e4f5d51b15..b81bd2af4ab 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -7,12 +7,14 @@ import torch_xla2.tensor class TestTorchFunctions(parameterized.TestCase): - @parameterized.named_parameters([ - ('tensor', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), + @parameterized.named_parameters( + ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), ('tensor_1d', lambda: torch.tensor([0, 1],)), ('tensor_scalar', lambda: torch.tensor(3.14159,)), ('tensor_empty', lambda: torch.tensor([],)), - ]) + ('ones_2d', lambda: torch.ones(2, 3)), + ('ones_1d', lambda: torch.ones(5)), + ) def test_tensor_constructor(self, func: Callable[[], torch.Tensor]): expected = func() diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py index cd486afb0f7..10e74ec13f5 100644 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ b/experimental/torch_xla2/torch_xla2/functions.py @@ -1,15 +1,20 @@ """Tensor constructor overrides""" import logging -from typing import Union import warnings import torch import jax.numpy as jnp from torch_xla2 import tensor +# TODO: registry +# TODO: correct types +def ones(*size: int): + return jnp.ones(size) + + fns = { torch.tensor: jnp.array, - # torch.ones: jnp.ones, + torch.ones: ones, # torch.zeros: jnp.zeros, # torch.arange: jnp.arange, # torch.linspace: jnp.linspace, @@ -24,7 +29,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor: jax_func = fns.get(func) if not jax_func: logging.warn(f'Falling back to default implementation of {func.__name__}') - func(*args, **kwargs) + return func(*args, **(kwargs or {})) if kwargs: warnings.warn(f'kwargs not implemented for {kwargs}')