Skip to content

Commit

Permalink
add some more basic constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Apr 4, 2024
1 parent 99edc32 commit d58da14
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
5 changes: 5 additions & 0 deletions experimental/torch_xla2/test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class TestTorchFunctions(parameterized.TestCase):
('tensor_empty', lambda: torch.tensor([],)),
('ones_2d', lambda: torch.ones(2, 3)),
('ones_1d', lambda: torch.ones(5)),
('zeros_2d', lambda: torch.zeros(2, 3)),
('zeros_1d', lambda: torch.zeros(5)),
('eye_3x3', lambda: torch.eye(3)),
('eye_4x2', lambda: torch.eye(4, 2)),
('full_tuple', lambda: torch.full((2, 3), 3.141592)),
)
def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
expected = func()
Expand Down
55 changes: 39 additions & 16 deletions experimental/torch_xla2/torch_xla2/functions.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,60 @@
"""Tensor constructor overrides"""
import functools
import logging
from typing import Callable, Optional, ParamSpec, Sequence
import warnings

import jax
import torch
import jax.numpy as jnp
from torch_xla2 import tensor

# TODO: registry
# TODO: correct types
def ones(*size: int):
registry = {}

P = ParamSpec('P')

def register_function(torch_func: Callable[P, torch.Tensor]):
def decorator(jax_impl: Callable[P, jax.Array]):
@functools.wraps(torch_func)
def wrapper(*args: P.args, **kwargs: P.kwargs):
return jax_impl(*args, **kwargs)

registry[torch_func] = jax_impl

return wrapper
return decorator


@register_function(torch.tensor)
def _tensor(data, *args, **kwargs):
return jnp.array(data)

@register_function(torch.ones)
def _ones(*size: int, **kwargs):
return jnp.ones(size)

@register_function(torch.zeros)
def _zeros(*size: int, **kwargs):
return jnp.zeros(size)

@register_function(torch.eye)
def _eye(n: int, m: Optional[int] = None, **kwargs):
return jnp.eye(n, m)

fns = {
torch.tensor: jnp.array,
torch.ones: ones,
# torch.zeros: jnp.zeros,
# torch.arange: jnp.arange,
# torch.linspace: jnp.linspace,
# torch.logspace: jnp.logspace,
# torch.empty: jnp.empty,
# torch.eye: jnp.eye,
# torch.full: jnp.full,
}
@register_function(torch.full)
def _full(size: Sequence[int], fill_value, **kwargs):
# TODO: handle torch.Size
return jnp.full(size, fill_value)

class XLAFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
jax_func = fns.get(func)
jax_func = registry.get(func)
if not jax_func:
logging.warn(f'Falling back to default implementation of {func.__name__}')
return func(*args, **(kwargs or {}))

if kwargs:
warnings.warn(f'kwargs not implemented for {kwargs}')

return tensor.wrap(jax_func(*tensor.unwrap(args)))
# TODO: unwrap args here or in implementations?
return tensor.wrap(jax_func(*args))

0 comments on commit d58da14

Please sign in to comment.