From d58da1438a8334aac2555ea08b2900f6b129ccbd Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 4 Apr 2024 23:03:24 +0000 Subject: [PATCH] add some more basic constructors --- .../torch_xla2/test/test_functions.py | 5 ++ .../torch_xla2/torch_xla2/functions.py | 55 +++++++++++++------ 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index b81bd2af4ab..6608de04881 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -14,6 +14,11 @@ class TestTorchFunctions(parameterized.TestCase): ('tensor_empty', lambda: torch.tensor([],)), ('ones_2d', lambda: torch.ones(2, 3)), ('ones_1d', lambda: torch.ones(5)), + ('zeros_2d', lambda: torch.zeros(2, 3)), + ('zeros_1d', lambda: torch.zeros(5)), + ('eye_3x3', lambda: torch.eye(3)), + ('eye_4x2', lambda: torch.eye(4, 2)), + ('full_tuple', lambda: torch.full((2, 3), 3.141592)), ) def test_tensor_constructor(self, func: Callable[[], torch.Tensor]): expected = func() diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py index 10e74ec13f5..73f60aa9149 100644 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ b/experimental/torch_xla2/torch_xla2/functions.py @@ -1,32 +1,54 @@ """Tensor constructor overrides""" +import functools import logging +from typing import Callable, Optional, ParamSpec, Sequence import warnings +import jax import torch import jax.numpy as jnp from torch_xla2 import tensor -# TODO: registry -# TODO: correct types -def ones(*size: int): +registry = {} + +P = ParamSpec('P') + +def register_function(torch_func: Callable[P, torch.Tensor]): + def decorator(jax_impl: Callable[P, jax.Array]): + @functools.wraps(torch_func) + def wrapper(*args: P.args, **kwargs: P.kwargs): + return jax_impl(*args, **kwargs) + + registry[torch_func] = jax_impl + + return wrapper + return decorator + + +@register_function(torch.tensor) +def _tensor(data, *args, **kwargs): + return jnp.array(data) + +@register_function(torch.ones) +def _ones(*size: int, **kwargs): return jnp.ones(size) +@register_function(torch.zeros) +def _zeros(*size: int, **kwargs): + return jnp.zeros(size) + +@register_function(torch.eye) +def _eye(n: int, m: Optional[int] = None, **kwargs): + return jnp.eye(n, m) -fns = { - torch.tensor: jnp.array, - torch.ones: ones, - # torch.zeros: jnp.zeros, - # torch.arange: jnp.arange, - # torch.linspace: jnp.linspace, - # torch.logspace: jnp.logspace, - # torch.empty: jnp.empty, - # torch.eye: jnp.eye, - # torch.full: jnp.full, -} +@register_function(torch.full) +def _full(size: Sequence[int], fill_value, **kwargs): + # TODO: handle torch.Size + return jnp.full(size, fill_value) class XLAFunctionMode(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor: - jax_func = fns.get(func) + jax_func = registry.get(func) if not jax_func: logging.warn(f'Falling back to default implementation of {func.__name__}') return func(*args, **(kwargs or {})) @@ -34,4 +56,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor: if kwargs: warnings.warn(f'kwargs not implemented for {kwargs}') - return tensor.wrap(jax_func(*tensor.unwrap(args))) + # TODO: unwrap args here or in implementations? + return tensor.wrap(jax_func(*args))