diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 8232ee74fa7..76e842d6fdd 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -6,24 +6,28 @@ import torch_xla2.functions import torch_xla2.tensor + class TestTorchFunctions(parameterized.TestCase): + @parameterized.named_parameters( - ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), - ('tensor_1d', lambda: torch.tensor([0, 1],)), - ('tensor_scalar', lambda: torch.tensor(3.14159,)), - ('tensor_empty', lambda: torch.tensor([],)), - ('tensor_dtype', lambda: torch.tensor([[0.11111, 0.222222, 0.3333333]], dtype=torch.float64)), - ('ones_2d', lambda: torch.ones(2, 3)), - ('ones_1d', lambda: torch.ones(5)), - ('ones_1d_dtype', lambda: torch.ones(5, dtype=torch.float16)), - ('zeros_2d', lambda: torch.zeros(2, 3)), - ('zeros_1d', lambda: torch.zeros(5)), - ('zeros_1d_dtype', lambda: torch.zeros(5, dtype=torch.complex64)), - ('eye_3x3', lambda: torch.eye(3)), - ('eye_4x2', lambda: torch.eye(4, 2)), - ('eye_4x2_dtype', lambda: torch.eye(4, 2, dtype=torch.float16)), - ('full_2d', lambda: torch.full((2, 3), 3.141592)), - ('full_2d_dtype', lambda: torch.full((2, 3), 3.141592, dtype=torch.float16)), + ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])), + ('tensor_1d', lambda: torch.tensor([0, 1],)), + ('tensor_scalar', lambda: torch.tensor(3.14159,)), + ('tensor_empty', lambda: torch.tensor([],)), + ('tensor_dtype', lambda: torch.tensor([[0.11111, 0.222222, 0.3333333]], + dtype=torch.float64)), + ('ones_2d', lambda: torch.ones(2, 3)), + ('ones_1d', lambda: torch.ones(5)), + ('ones_1d_dtype', lambda: torch.ones(5, dtype=torch.float16)), + ('zeros_2d', lambda: torch.zeros(2, 3)), + ('zeros_1d', lambda: torch.zeros(5)), + ('zeros_1d_dtype', lambda: torch.zeros(5, dtype=torch.complex64)), + ('eye_3x3', lambda: torch.eye(3)), + ('eye_4x2', lambda: torch.eye(4, 2)), + ('eye_4x2_dtype', lambda: torch.eye(4, 2, dtype=torch.float16)), + ('full_2d', lambda: torch.full((2, 3), 3.141592)), + ('full_2d_dtype', lambda: torch.full( + (2, 3), 3.141592, dtype=torch.float16)), ) def test_tensor_constructor(self, func: Callable[[], torch.Tensor]): expected = func() diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index f0fb68b3181..4d07006fcd0 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -7,6 +7,7 @@ jax.config.update('jax_enable_x64', True) + def extract_jax(mod: torch.nn.Module): """Returns a pytree of jax.ndarray and a jax callable.""" func, weights, buffer = make_functional.make_functional_with_buffers(mod) diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py index bab8974c207..d1710405ccc 100644 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ b/experimental/torch_xla2/torch_xla2/functions.py @@ -12,8 +12,11 @@ P = ParamSpec('P') + def register_function(torch_func: Callable[P, torch.Tensor]): + def decorator(jax_impl: Callable[P, jax.Array]): + @functools.wraps(torch_func) def wrapper(*args: P.args, **kwargs: P.kwargs): return jax_impl(*args, **kwargs) @@ -21,12 +24,18 @@ def wrapper(*args: P.args, **kwargs: P.kwargs): registry[torch_func] = jax_impl return wrapper + return decorator + def convert_dtype(use_default_dtype: bool = True): + def decorator(func: Callable[P, torch.Tensor]): + @functools.wraps(func) - def wrapper(*args: P.args, dtype: Optional[torch.dtype] = None, **kwargs: P.kwargs): + def wrapper(*args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs): if not dtype and use_default_dtype: dtype = torch.get_default_dtype() jax_dtype = tensor.t2j_dtype(dtype) @@ -37,45 +46,57 @@ def wrapper(*args: P.args, dtype: Optional[torch.dtype] = None, **kwargs: P.kwar return decorator + @register_function(torch.tensor) -@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements +@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements def _tensor(data, *, dtype=None, **kwargs): python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, + bool: jnp.bool, + int: jnp.int64, + float: jnp.float32, + complex: jnp.complex64, } if not dtype: leaves = jax.tree_util.tree_leaves(data) if len(leaves) > 0: dtype = python_types_to_torch_types.get(type(leaves[0])) - return jnp.array(data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype())) + return jnp.array( + data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype())) + @register_function(torch.ones) @convert_dtype() def _ones(*size: int, dtype=None, **kwargs): return jnp.ones(size, dtype) + @register_function(torch.zeros) @convert_dtype() def _zeros(*size: int, dtype=None, **kwargs): return jnp.zeros(size, dtype) + @register_function(torch.eye) @convert_dtype() def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): return jnp.eye(n, m, dtype=dtype) + @register_function(torch.full) @convert_dtype() def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): # TODO: handle torch.Size return jnp.full(size, fill_value, dtype=dtype) + class XLAFunctionMode(torch.overrides.TorchFunctionMode): - def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor: + + def __torch_function__(self, + func, + types, + args=(), + kwargs=None) -> torch.Tensor: jax_func = registry.get(func) if not jax_func: logging.warn(f'Falling back to default implementation of {func.__name__}')