formatting
will-cromar committed Apr 5, 2024
1 parent 04376e8 commit 40ae313
Showing 3 changed files with 50 additions and 24 deletions.
experimental/torch_xla2/test/test_functions.py: 20 additions & 16 deletions
@@ -6,24 +6,28 @@
 import torch_xla2.functions
 import torch_xla2.tensor
 
+
 class TestTorchFunctions(parameterized.TestCase):
+
   @parameterized.named_parameters(
-      ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
-      ('tensor_1d', lambda: torch.tensor([0, 1],)),
-      ('tensor_scalar', lambda: torch.tensor(3.14159,)),
-      ('tensor_empty', lambda: torch.tensor([],)),
-      ('tensor_dtype', lambda: torch.tensor([[0.11111, 0.222222, 0.3333333]], dtype=torch.float64)),
-      ('ones_2d', lambda: torch.ones(2, 3)),
-      ('ones_1d', lambda: torch.ones(5)),
-      ('ones_1d_dtype', lambda: torch.ones(5, dtype=torch.float16)),
-      ('zeros_2d', lambda: torch.zeros(2, 3)),
-      ('zeros_1d', lambda: torch.zeros(5)),
-      ('zeros_1d_dtype', lambda: torch.zeros(5, dtype=torch.complex64)),
-      ('eye_3x3', lambda: torch.eye(3)),
-      ('eye_4x2', lambda: torch.eye(4, 2)),
-      ('eye_4x2_dtype', lambda: torch.eye(4, 2, dtype=torch.float16)),
-      ('full_2d', lambda: torch.full((2, 3), 3.141592)),
-      ('full_2d_dtype', lambda: torch.full((2, 3), 3.141592, dtype=torch.float16)),
+      ('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
+      ('tensor_1d', lambda: torch.tensor([0, 1],)),
+      ('tensor_scalar', lambda: torch.tensor(3.14159,)),
+      ('tensor_empty', lambda: torch.tensor([],)),
+      ('tensor_dtype', lambda: torch.tensor([[0.11111, 0.222222, 0.3333333]],
+                                            dtype=torch.float64)),
+      ('ones_2d', lambda: torch.ones(2, 3)),
+      ('ones_1d', lambda: torch.ones(5)),
+      ('ones_1d_dtype', lambda: torch.ones(5, dtype=torch.float16)),
+      ('zeros_2d', lambda: torch.zeros(2, 3)),
+      ('zeros_1d', lambda: torch.zeros(5)),
+      ('zeros_1d_dtype', lambda: torch.zeros(5, dtype=torch.complex64)),
+      ('eye_3x3', lambda: torch.eye(3)),
+      ('eye_4x2', lambda: torch.eye(4, 2)),
+      ('eye_4x2_dtype', lambda: torch.eye(4, 2, dtype=torch.float16)),
+      ('full_2d', lambda: torch.full((2, 3), 3.141592)),
+      ('full_2d_dtype', lambda: torch.full(
+          (2, 3), 3.141592, dtype=torch.float16)),
   )
   def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
     expected = func()
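Note: the test body is truncated in this view, but the visible lines show the pattern: each named parameter is a constructor thunk, evaluated once eagerly for the expected value. A minimal sketch of how such a comparison test can be completed is below; the `XLAFunctionMode` context manager comes from functions.py in this same commit, while the `j2t` conversion helper and the exact assertion are assumptions, not shown in this diff:

```python
import torch
from absl.testing import absltest, parameterized

import torch_xla2.functions
import torch_xla2.tensor


class TestTorchFunctionsSketch(parameterized.TestCase):

  @parameterized.named_parameters(
      ('ones_2d', lambda: torch.ones(2, 3)),
      ('zeros_1d', lambda: torch.zeros(5)),
  )
  def test_tensor_constructor(self, func):
    expected = func()  # plain eager PyTorch result

    # Re-run the same constructor with torch.* rerouted to the JAX impls.
    with torch_xla2.functions.XLAFunctionMode():
      actual = func()

    # Assumed comparison step: convert the JAX result back to a
    # torch.Tensor before asserting closeness.
    torch.testing.assert_close(torch_xla2.tensor.j2t(actual), expected)


if __name__ == '__main__':
  absltest.main()
```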
experimental/torch_xla2/torch_xla2/__init__.py: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 
 jax.config.update('jax_enable_x64', True)
 
+
 def extract_jax(mod: torch.nn.Module):
   """Returns a pytree of jax.ndarray and a jax callable."""
   func, weights, buffer = make_functional.make_functional_with_buffers(mod)
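The context lines above also show `extract_jax`, which functionalizes a module via `make_functional_with_buffers` and, per its docstring, returns a pytree of `jax.ndarray` plus a JAX callable. A hedged usage sketch; the unpacking order and the calling convention of the returned callable are assumptions based on the docstring, not confirmed by this diff:

```python
import torch
import torch_xla2

model = torch.nn.Linear(4, 2)

# Docstring: "Returns a pytree of jax.ndarray and a jax callable."
weights, jax_func = torch_xla2.extract_jax(model)

# Assumed calling convention: weights/buffers first, then model inputs.
# outputs = jax_func(weights, inputs)
```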
experimental/torch_xla2/torch_xla2/functions.py: 29 additions & 8 deletions
@@ -12,21 +12,30 @@
 
 P = ParamSpec('P')
 
+
 def register_function(torch_func: Callable[P, torch.Tensor]):
+
   def decorator(jax_impl: Callable[P, jax.Array]):
+
     @functools.wraps(torch_func)
     def wrapper(*args: P.args, **kwargs: P.kwargs):
       return jax_impl(*args, **kwargs)
 
     registry[torch_func] = jax_impl
+
     return wrapper
 
   return decorator
 
+
 def convert_dtype(use_default_dtype: bool = True):
+
   def decorator(func: Callable[P, torch.Tensor]):
+
     @functools.wraps(func)
-    def wrapper(*args: P.args, dtype: Optional[torch.dtype] = None, **kwargs: P.kwargs):
+    def wrapper(*args: P.args,
+                dtype: Optional[torch.dtype] = None,
+                **kwargs: P.kwargs):
       if not dtype and use_default_dtype:
         dtype = torch.get_default_dtype()
       jax_dtype = tensor.t2j_dtype(dtype)
@@ -37,45 +46,57 @@ def wrapper(*args: P.args, dtype: Optional[torch.dtype] = None, **kwargs: P.kwargs):
 
   return decorator
 
+
 @register_function(torch.tensor)
-@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements
+@convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
   python_types_to_torch_types = {
-    bool: jnp.bool,
-    int: jnp.int64,
-    float: jnp.float32,
-    complex: jnp.complex64,
+      bool: jnp.bool,
+      int: jnp.int64,
+      float: jnp.float32,
+      complex: jnp.complex64,
   }
   if not dtype:
     leaves = jax.tree_util.tree_leaves(data)
     if len(leaves) > 0:
       dtype = python_types_to_torch_types.get(type(leaves[0]))
 
-  return jnp.array(data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype()))
+  return jnp.array(
+      data, dtype=dtype or tensor.t2j_dtype(torch.get_default_dtype()))
 
+
 @register_function(torch.ones)
 @convert_dtype()
 def _ones(*size: int, dtype=None, **kwargs):
   return jnp.ones(size, dtype)
 
+
 @register_function(torch.zeros)
 @convert_dtype()
 def _zeros(*size: int, dtype=None, **kwargs):
   return jnp.zeros(size, dtype)
 
+
 @register_function(torch.eye)
 @convert_dtype()
 def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
   return jnp.eye(n, m, dtype=dtype)
 
+
 @register_function(torch.full)
 @convert_dtype()
 def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
   # TODO: handle torch.Size
   return jnp.full(size, fill_value, dtype=dtype)
 
+
 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
-  def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor:
+
+  def __torch_function__(self,
+                         func,
+                         types,
+                         args=(),
+                         kwargs=None) -> torch.Tensor:
     jax_func = registry.get(func)
     if not jax_func:
       logging.warn(f'Falling back to default implementation of {func.__name__}')
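Taken together, functions.py implements a simple dispatch table: `register_function` maps a `torch` callable to a JAX implementation, and `XLAFunctionMode` intercepts `torch.*` calls, rerouting registered ones and falling back (with a warning) otherwise. Below is an illustrative, dependency-free sketch of that same pattern using only PyTorch; `DemoFunctionMode` and the print-based `_ones` stand-in are invented for the example and do not appear in the commit:

```python
import torch
from torch.overrides import TorchFunctionMode

# Dispatch table from torch callables to replacement implementations.
registry = {}


def register_function(torch_func):

  def decorator(impl):
    registry[torch_func] = impl
    return impl

  return decorator


@register_function(torch.ones)
def _ones(*size, dtype=None, **kwargs):
  # Stand-in implementation; the real code builds a jnp.ones array.
  print(f'intercepted torch.ones with size={size}')
  return torch.full(size, 1.0, dtype=dtype)


class DemoFunctionMode(TorchFunctionMode):

  def __torch_function__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    impl = registry.get(func)
    if impl is not None:
      return impl(*args, **kwargs)
    # Unregistered ops fall back to the default implementation
    # (the real mode logs a warning here first).
    return func(*args, **kwargs)


with DemoFunctionMode():
  x = torch.ones(2, 3)  # routed through _ones
print(x)
```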
