-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
99edc32
commit d58da14
Showing
2 changed files
with
44 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,60 @@ | ||
"""Tensor constructor overrides""" | ||
import functools | ||
import logging | ||
from typing import Callable, Optional, ParamSpec, Sequence | ||
import warnings | ||
|
||
import jax | ||
import torch | ||
import jax.numpy as jnp | ||
from torch_xla2 import tensor | ||
|
||
# TODO: registry | ||
# TODO: correct types | ||
def ones(*size: int): | ||
registry = {} | ||
|
||
P = ParamSpec('P') | ||
|
||
def register_function(torch_func: Callable[P, torch.Tensor]): | ||
def decorator(jax_impl: Callable[P, jax.Array]): | ||
@functools.wraps(torch_func) | ||
def wrapper(*args: P.args, **kwargs: P.kwargs): | ||
return jax_impl(*args, **kwargs) | ||
|
||
registry[torch_func] = jax_impl | ||
|
||
return wrapper | ||
return decorator | ||
|
||
|
||
@register_function(torch.tensor) | ||
def _tensor(data, *args, **kwargs): | ||
return jnp.array(data) | ||
|
||
@register_function(torch.ones) | ||
def _ones(*size: int, **kwargs): | ||
return jnp.ones(size) | ||
|
||
@register_function(torch.zeros) | ||
def _zeros(*size: int, **kwargs): | ||
return jnp.zeros(size) | ||
|
||
@register_function(torch.eye) | ||
def _eye(n: int, m: Optional[int] = None, **kwargs): | ||
return jnp.eye(n, m) | ||
|
||
fns = { | ||
torch.tensor: jnp.array, | ||
torch.ones: ones, | ||
# torch.zeros: jnp.zeros, | ||
# torch.arange: jnp.arange, | ||
# torch.linspace: jnp.linspace, | ||
# torch.logspace: jnp.logspace, | ||
# torch.empty: jnp.empty, | ||
# torch.eye: jnp.eye, | ||
# torch.full: jnp.full, | ||
} | ||
@register_function(torch.full) | ||
def _full(size: Sequence[int], fill_value, **kwargs): | ||
# TODO: handle torch.Size | ||
return jnp.full(size, fill_value) | ||
|
||
class XLAFunctionMode(torch.overrides.TorchFunctionMode): | ||
def __torch_function__(self, func, types, args=(), kwargs=None) -> torch.Tensor: | ||
jax_func = fns.get(func) | ||
jax_func = registry.get(func) | ||
if not jax_func: | ||
logging.warn(f'Falling back to default implementation of {func.__name__}') | ||
return func(*args, **(kwargs or {})) | ||
|
||
if kwargs: | ||
warnings.warn(f'kwargs not implemented for {kwargs}') | ||
|
||
return tensor.wrap(jax_func(*tensor.unwrap(args))) | ||
# TODO: unwrap args here or in implementations? | ||
return tensor.wrap(jax_func(*args)) |