Commit

ones
will-cromar committed Apr 3, 2024
1 parent c06a0c8 commit 99edc32
Showing 2 changed files with 13 additions and 6 deletions.
experimental/torch_xla2/test/test_functions.py (8 changes: 5 additions & 3 deletions)
@@ -7,12 +7,14 @@
 import torch_xla2.tensor
 
 class TestTorchFunctions(parameterized.TestCase):
-  @parameterized.named_parameters([
-      ('tensor', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
+  @parameterized.named_parameters(
+      ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
       ('tensor_1d', lambda: torch.tensor([0, 1],)),
       ('tensor_scalar', lambda: torch.tensor(3.14159,)),
       ('tensor_empty', lambda: torch.tensor([],)),
-  ])
+      ('ones_2d', lambda: torch.ones(2, 3)),
+      ('ones_1d', lambda: torch.ones(5)),
+  )
   def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
     expected = func()

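For context, the body of test_tensor_constructor is truncated above; presumably it compares the eager result against the same constructor call run under torch_xla2's function-interception mode. A minimal sketch of that comparison, assuming the mode is exposed as torch_xla2.functions.XLAFunctionMode, that the result stores its JAX array in an _elem attribute, and that torch_xla2.tensor.j2t converts a JAX array back to a torch.Tensor (all three names are assumptions, not shown in this diff):

import torch
import torch_xla2.functions
import torch_xla2.tensor

def check_constructor(func):
  # Eager PyTorch result used as the reference value.
  expected = func()

  # Re-run the same constructor with the override mode active (class name assumed).
  with torch_xla2.functions.XLAFunctionMode():
    actual = func()

  # Unwrap and convert back to a torch.Tensor before comparing (j2t and _elem are assumed names).
  torch.testing.assert_close(torch_xla2.tensor.j2t(actual._elem), expected)

check_constructor(lambda: torch.ones(2, 3))  # exercises the new 'ones_2d' case

The new ones_2d and ones_1d parameters matter because torch.ones takes its shape as variadic integers, which the plain jnp.ones signature does not accept; the second file adds a wrapper for exactly that mismatch.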
experimental/torch_xla2/torch_xla2/functions.py (11 changes: 8 additions & 3 deletions)
@@ -1,15 +1,20 @@
"""Tensor constructor overrides"""
import logging
from typing import Union
import warnings

import torch
import jax.numpy as jnp
from torch_xla2 import tensor

# TODO: registry
# TODO: correct types
def ones(*size: int):
return jnp.ones(size)


fns = {
torch.tensor: jnp.array,
# torch.ones: jnp.ones,
torch.ones: ones,
# torch.zeros: jnp.zeros,
# torch.arange: jnp.arange,
# torch.linspace: jnp.linspace,
@@ -24,7 +29,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
     jax_func = fns.get(func)
     if not jax_func:
       logging.warn(f'Falling back to default implementation of {func.__name__}')
-      func(*args, **kwargs)
+      return func(*args, **(kwargs or {}))
 
     if kwargs:
       warnings.warn(f'kwargs not implemented for {kwargs}')
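Two things are worth noting in this hunk. First, torch.ones(*size) takes its shape as separate integer arguments while jnp.ones(shape, dtype=...) expects a single shape tuple (jnp.ones(2, 3) treats 3 as a dtype and fails), so the previously commented-out direct mapping is replaced by a small wrapper that packs *size into a tuple. Second, the fallback path now returns the default implementation's result instead of discarding it. A simplified stand-in for the lookup-then-fallback flow, not the real __torch_function__ method (which presumably also wraps results in torch_xla2 tensors and handles kwargs):

import torch
import jax.numpy as jnp

def ones(*size: int):
  # Pack the variadic torch-style size into the single shape tuple jnp expects.
  return jnp.ones(size)

fns = {torch.tensor: jnp.array, torch.ones: ones}

def dispatch(func, *args, **kwargs):
  # Mirrors the hunk above: look up a JAX replacement, otherwise fall back
  # to (and now actually return) the default torch implementation.
  jax_func = fns.get(func)
  if not jax_func:
    return func(*args, **(kwargs or {}))
  return jax_func(*args)

print(dispatch(torch.ones, 2, 3))   # JAX array of shape (2, 3) via the new wrapper
print(dispatch(torch.zeros, 2, 3))  # not in fns: falls back to torch.zeros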
