diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py index 05be27878c4..2b372695e65 100644 --- a/experimental/torch_xla2/torch_xla2/functions.py +++ b/experimental/torch_xla2/torch_xla2/functions.py @@ -14,6 +14,7 @@ def register_function(torch_func: Callable[P, torch.Tensor]): + """Registers a function as the JAX implementation of a torch function.""" def decorator(jax_impl: Callable[P, jax.Array]): registry[torch_func] = jax_impl @@ -23,6 +24,14 @@ def decorator(jax_impl: Callable[P, jax.Array]): def convert_dtype(use_default_dtype: bool = True): + """Converts `dtype` kwarg of function from torch to JAX. + + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. + + Returns: + A decorator that wraps a JAX implementation of a torch function. + """ def decorator(func: Callable[P, torch.Tensor]): @@ -85,6 +94,7 @@ def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): class XLAFunctionMode(torch.overrides.TorchFunctionMode): + """Context manager that dispatches torch function calls to JAX.""" def __torch_function__(self, func,