use correct dtypes
will-cromar committed Apr 5, 2024
1 parent d58da14 commit 04376e8
Showing 4 changed files with 52 additions and 22 deletions.
10 changes: 7 additions & 3 deletions experimental/torch_xla2/test/test_functions.py
@@ -12,13 +12,18 @@ class TestTorchFunctions(parameterized.TestCase):
('tensor_1d', lambda: torch.tensor([0, 1],)),
('tensor_scalar', lambda: torch.tensor(3.14159,)),
('tensor_empty', lambda: torch.tensor([],)),
('tensor_dtype', lambda: torch.tensor([[0.11111, 0.222222, 0.3333333]], dtype=torch.float64)),
('ones_2d', lambda: torch.ones(2, 3)),
('ones_1d', lambda: torch.ones(5)),
('ones_1d_dtype', lambda: torch.ones(5, dtype=torch.float16)),
('zeros_2d', lambda: torch.zeros(2, 3)),
('zeros_1d', lambda: torch.zeros(5)),
('zeros_1d_dtype', lambda: torch.zeros(5, dtype=torch.complex64)),
('eye_3x3', lambda: torch.eye(3)),
('eye_4x2', lambda: torch.eye(4, 2)),
('full_tuple', lambda: torch.full((2, 3), 3.141592)),
('eye_4x2_dtype', lambda: torch.eye(4, 2, dtype=torch.float16)),
('full_2d', lambda: torch.full((2, 3), 3.141592)),
('full_2d_dtype', lambda: torch.full((2, 3), 3.141592, dtype=torch.float16)),
)
def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
expected = func()
@@ -27,8 +32,7 @@ def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
actual = func()
self.assertIsInstance(actual, torch_xla2.tensor.XLATensor2)

# TODO: dtype is actually important
torch.testing.assert_close(torch_xla2.tensor.j2t(actual._elem), expected, check_dtype=False)
torch.testing.assert_close(torch_xla2.tensor.j2t(actual._elem), expected)


if __name__ == '__main__':
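The new cases above pin down the dtype produced by each constructor, and the comparison no longer passes check_dtype=False, so a dtype mismatch now fails the test. A minimal standalone sketch of the same assertion (assuming XLAFunctionMode from functions.py can be entered directly; the test suite may use its own wrapper):

import torch
from torch_xla2 import functions, tensor

# Eager PyTorch fixes the expected value and dtype.
expected = torch.ones(5, dtype=torch.float16)

# Assumption: entering the mode directly routes torch.ones to the JAX implementation.
with functions.XLAFunctionMode():
    actual = torch.ones(5, dtype=torch.float16)

# check_dtype defaults to True, so a float32 result would now fail here.
torch.testing.assert_close(tensor.j2t(actual._elem), expected)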
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/__init__.py
@@ -5,7 +5,7 @@
from torch_xla2 import tensor
from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration


jax.config.update('jax_enable_x64', True)

def extract_jax(mod: torch.nn.Module):
"""Returns a pytree of jax.ndarray and a jax callable."""
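Enabling jax_enable_x64 is part of the dtype fix: by default JAX demotes 64-bit values to 32-bit, so a torch.float64 (or torch.int64) tensor could never keep its dtype on the JAX side. A quick sketch of the difference, using plain JAX rather than anything specific to this repo:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

x = jnp.array([0.11111, 0.222222], dtype=jnp.float64)
print(x.dtype)  # float64 with the flag set; silently float32 without it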
55 changes: 40 additions & 15 deletions experimental/torch_xla2/torch_xla2/functions.py
@@ -2,7 +2,6 @@
import functools
import logging
from typing import Callable, Optional, ParamSpec, Sequence
import warnings

import jax
import torch
@@ -24,27 +23,56 @@ def wrapper(*args: P.args, **kwargs: P.kwargs):
return wrapper
return decorator

def convert_dtype(use_default_dtype: bool = True):
def decorator(func: Callable[P, torch.Tensor]):
@functools.wraps(func)
def wrapper(*args: P.args, dtype: Optional[torch.dtype] = None, **kwargs: P.kwargs):
if not dtype and use_default_dtype:
dtype = torch.get_default_dtype()
jax_dtype = tensor.t2j_dtype(dtype)

return func(*args, dtype=jax_dtype, **kwargs)

return wrapper

return decorator

@register_function(torch.tensor)
def _tensor(data, *args, **kwargs):
return jnp.array(data)
@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements
def _tensor(data, *, dtype=None, **kwargs):
python_types_to_torch_types = {
bool: jnp.bool,
int: jnp.int64,
float: jnp.float32,
complex: jnp.complex64,
}
if not dtype:
leaves = jax.tree_util.tree_leaves(data)
if len(leaves) > 0:
dtype = python_types_to_torch_types.get(type(leaves[0]))

return jnp.array(data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype()))

@register_function(torch.ones)
def _ones(*size: int, **kwargs):
return jnp.ones(size)
@convert_dtype()
def _ones(*size: int, dtype=None, **kwargs):
return jnp.ones(size, dtype)

@register_function(torch.zeros)
def _zeros(*size: int, **kwargs):
return jnp.zeros(size)
@convert_dtype()
def _zeros(*size: int, dtype=None, **kwargs):
return jnp.zeros(size, dtype)

@register_function(torch.eye)
def _eye(n: int, m: Optional[int] = None, **kwargs):
return jnp.eye(n, m)
@convert_dtype()
def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
return jnp.eye(n, m, dtype=dtype)

@register_function(torch.full)
def _full(size: Sequence[int], fill_value, **kwargs):
@convert_dtype()
def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
# TODO: handle torch.Size
return jnp.full(size, fill_value)
return jnp.full(size, fill_value, dtype=dtype)

class XLAFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
@@ -53,8 +81,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
logging.warn(f'Falling back to default implementation of {func.__name__}')
return func(*args, **(kwargs or {}))

if kwargs:
warnings.warn(f'kwargs not implemented for {kwargs}')

# TODO: unwrap args here or in implementations?
return tensor.wrap(jax_func(*args))
return tensor.wrap(jax_func(*args, **(kwargs or {})))
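The convert_dtype decorator above centralizes dtype handling for the factory functions: when the caller passes no dtype, it falls back to torch.get_default_dtype(), then translates the torch dtype to its JAX equivalent before calling the wrapped constructor. torch.tensor opts out of the automatic default so it can first infer a dtype from the leaf element type, only falling back to the default when inference fails. A reduced sketch of the same flow outside the decorator machinery (to_jax_dtype is an illustrative helper, not part of the module):

import torch
import jax.numpy as jnp
from torch_xla2 import tensor

def to_jax_dtype(dtype=None, use_default_dtype=True):
    # Mirrors what convert_dtype does before delegating to jnp.
    if not dtype and use_default_dtype:
        dtype = torch.get_default_dtype()
    return tensor.t2j_dtype(dtype)

print(jnp.ones((2, 3), to_jax_dtype()).dtype)            # float32 under torch's default
print(jnp.ones((2, 3), to_jax_dtype(torch.half)).dtype)  # float16 when requested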
7 changes: 4 additions & 3 deletions experimental/torch_xla2/torch_xla2/tensor.py
@@ -20,7 +20,7 @@ def __torch_dispatch__(self, fn, types, args=(), kwargs=None):
args, kwargs = unwrap((args, kwargs))
res = constructors[fn](*args, **kwargs)
return wrap(res)

return fn(*args, **kwargs)


@@ -97,6 +97,7 @@ def j2t(x):

def t2j_dtype(dtype):
return {
torch.float16: jnp.float16,
torch.bfloat16: jnp.bfloat16,
torch.half: jnp.float16,
torch.float32: jnp.float32,
@@ -192,10 +193,10 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with jax.named_scope(func.name()):
if isinstance(func, torch._ops.OpOverloadPacket):
return func(*args, **kwargs)

if func in jaten.all_ops:
return jaten.all_ops[func](*args, **kwargs)

lowering = ops_registry.lowerings.lookup(func)

if lowering is None:
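The t2j_dtype table now carries an explicit torch.float16 entry alongside the existing torch.half alias, so half-precision tensors translate cleanly. A quick spot-check sketch:

import torch
import jax.numpy as jnp
from torch_xla2 import tensor

assert tensor.t2j_dtype(torch.float16) == jnp.float16
assert tensor.t2j_dtype(torch.half) == jnp.float16  # torch.half is the same dtype as torch.float16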
