diff --git a/experimental/torch_xla2/examples/lightning_training.py b/experimental/torch_xla2/examples/lightning_training.py
new file mode 100644
index 00000000000..b09f00d9473
--- /dev/null
+++ b/experimental/torch_xla2/examples/lightning_training.py
@@ -0,0 +1,77 @@
+import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
+import lightning as L
+
+encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
+decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
+
+class LitAutoEncoder(L.LightningModule):
+  def __init__(self, encoder, decoder):
+    super().__init__()
+    self.encoder, self.decoder = encoder, decoder
+
+  def training_step(self, batch, batch_idx):
+    x, y = batch
+    x = x.view(x.size(0), -1)
+    z = self.encoder(x)
+    x_hat = self.decoder(z)
+    loss = nn.functional.mse_loss(x_hat, x)
+    self.log("train_loss", loss)
+    return loss
+
+  def configure_optimizers(self):
+    return torch.optim.Adam(self.parameters(), lr=1e-3)
+
+dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
+
+# Lightning will automatically use all available GPUs!
+trainer = L.Trainer()
+# trainer.fit(LitAutoEncoder(encoder, decoder), data.DataLoader(dataset, batch_size=64))
+
+# ==== above is the lightning example from
+# https://lightning.ai/pytorch-lightning
+
+import torch_xla2
+from torch_xla2.interop import jax_view, torch_view
+import jax
+import optax
+
+class JaxTrainer:
+
+  def __init__(self):
+    pass
+
+  def torch_opt_to_jax_opt(self, torch_opt):
+    # TODO: Can convert optimizer instead of using a jax one
+    return optax.adam(0.001)
+
+  def fit(self, lightning_mod, data_loader):
+
+    xla_env = torch_xla2.default_env()
+
+    def lightning_mod_loss(
+        weights: jax.Array, data: jax.Array, batch_id):
+      """returns loss"""
+      weights, data = torch_view((weights, data))
+      lightning_mod.load_state_dict(weights, assign=True)
+      with xla_env:
+        loss = lightning_mod.training_step(data, batch_id)
+      return jax_view(loss)
+
+    jax_weights = jax_view(xla_env.to_xla(lightning_mod.state_dict()))
+    jax_optimizer = self.torch_opt_to_jax_opt(
+        lightning_mod.configure_optimizers())
+    opt_state = jax_optimizer.init(jax_weights)
+    grad_fn = jax.jit(jax.value_and_grad(lightning_mod_loss))
+
+    for bid in range(3):
+      for item in data_loader:
+        xla_data = jax_view(xla_env.to_xla(item))
+        loss, grads = grad_fn(jax_weights, xla_data, bid)
+        updates, opt_state = jax_optimizer.update(grads, opt_state)
+        jax_weights = optax.apply_updates(jax_weights, updates)
+        print('current_loss', loss)
+
+
+print('-----------------')
+trainer_jax = JaxTrainer()
+trainer_jax.fit(LitAutoEncoder(encoder, decoder), data.DataLoader(dataset, batch_size=64))
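
Note: JaxTrainer.fit above is the standard functional JAX/optax training loop applied to a Lightning module: the state_dict is treated as a pytree of weights (the jax.Array annotations on lightning_mod_loss are really pytrees), training_step becomes a pure loss over (weights, batch), and optax produces the parameter updates. A minimal sketch of the same pattern without Lightning, for reference only; the toy linear model and names below are illustrative, not part of this change:

    import jax
    import jax.numpy as jnp
    import optax

    # Stand-in for lightning_mod.state_dict(): a pytree of weights.
    params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}

    def loss_fn(params, batch):
      # Pure loss over (weights, batch), playing the role of lightning_mod_loss.
      x, y = batch
      pred = x @ params['w'] + params['b']
      return jnp.mean((pred - y) ** 2)

    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)
    grad_fn = jax.jit(jax.value_and_grad(loss_fn))

    def train_step(params, opt_state, batch):
      loss, grads = grad_fn(params, batch)
      updates, opt_state = optimizer.update(grads, opt_state)
      return optax.apply_updates(params, updates), opt_state, loss

    # usage: params, opt_state, loss = train_step(params, opt_state, (x, y))
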
diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py
index fbcd47922e1..d1a96179e82 100644
--- a/experimental/torch_xla2/torch_xla2/interop.py
+++ b/experimental/torch_xla2/torch_xla2/interop.py
@@ -10,8 +10,7 @@
 
 
 
-
-def torch_view(t: JaxValue) -> TorchValue:
+def _torch_view(t: JaxValue) -> TorchValue:
   # t is an object from jax land
   # view it as-if it's a torch land object
   if isinstance(t, jax.Array):
@@ -24,8 +23,10 @@ def torch_view(t: JaxValue) -> TorchValue:
   # regular types are not changed
   return t
 
+torch_view = functools.partial(pytree.tree_map, _torch_view)
+
 
-def jax_view(t: TorchValue) -> JaxValue:
+def _jax_view(t: TorchValue) -> JaxValue:
   # t is an object from torch land
   # view it as-if it's a jax land object
   if isinstance(t, torch.Tensor):
@@ -40,17 +41,19 @@ def jax_view(t: TorchValue) -> JaxValue:
   # regular types are not changed
   return t
 
+jax_view = functools.partial(pytree.tree_map, _jax_view)
+
 
 def call_jax(jax_func: JaxCallable,
              *args: TorchValue,
             **kwargs: TorchValue) -> TorchValue:
-  args, kwargs = pytree.tree_map(jax_view, (args, kwargs))
+  args, kwargs = jax_view((args, kwargs))
   res: JaxValue = jax_func(*args, **kwargs)
   return torch_view(res)
 
 
 def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue:
-  args, kwargs = pytree.tree_map(torch_view, (args, kwargs))
+  args, kwargs = torch_view((args, kwargs))
   with torch_xla2.default_env():
     res: TorchValue = torch_func(*args, **kwargs)
   return jax_view(res)
@@ -63,3 +66,4 @@ def jax_jit(torch_function, kwargs_for_jax_jit=None):
   jax_func = jax_view(torch_function)
   jitted = jax.jit(jax_func, **kwargs_for_jax_jit)
   return torch_view(jitted)
+
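
Note: with torch_view and jax_view now defined through pytree.tree_map, both functions accept arbitrarily nested containers and convert only the leaves; that is what lets the example above push a whole state_dict or DataLoader batch through them, and what lets call_jax/call_torch drop their own tree_map calls. A small sketch of the intended round trip, assuming default_env().to_xla accepts a plain dict of CPU tensors the same way it accepts the state_dict above (the dict below is illustrative):

    import torch
    import torch_xla2
    from torch_xla2.interop import jax_view, torch_view

    env = torch_xla2.default_env()
    nested = {'w': torch.ones(2, 2), 'b': torch.zeros(2)}

    as_jax = jax_view(env.to_xla(nested))   # same dict structure, leaves viewed as jax.Array
    as_torch = torch_view(as_jax)           # same dict structure, leaves viewed as torch tensors
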
- """ - - def decorator(func: Callable[P, torch.Tensor]): - - @functools.wraps(func) - def wrapper(*args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs): - if not dtype and use_default_dtype: - dtype = torch.get_default_dtype() - jax_dtype = tensor.t2j_dtype(dtype) - - return func(*args, dtype=jax_dtype, **kwargs) - - return wrapper - - return decorator - - @register_function(torch.tensor) -@convert_dtype(use_default_dtype=False) # Attempt to infer type from elements +@op_base.convert_dtype(use_default_dtype=False) # Attempt to infer type from elements def _tensor(data, *, dtype=None, **kwargs): python_types_to_torch_types = { bool: jnp.bool, @@ -61,25 +33,25 @@ def _tensor(data, *, dtype=None, **kwargs): @register_function(torch.ones) -@convert_dtype() +@op_base.convert_dtype() def _ones(*size: int, dtype=None, **kwargs): return jnp.ones(size, dtype) @register_function(torch.zeros) -@convert_dtype() +@op_base.convert_dtype() def _zeros(*size: int, dtype=None, **kwargs): return jnp.zeros(size, dtype) @register_function(torch.eye) -@convert_dtype() +@op_base.convert_dtype() def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs): return jnp.eye(n, m, dtype=dtype) @register_function(torch.full) -@convert_dtype() +@op_base.convert_dtype() def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs): # TODO: handle torch.Size return jnp.full(size, fill_value, dtype=dtype) diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py index 983d20fb660..8076a11fb09 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py @@ -1,5 +1,9 @@ +import functools import torch -from torch_xla2 import interop +from torch_xla2 import interop, tensor +from torch_xla2 import types + +from typing import Callable, Optional, ParamSpec, Sequence class BinaryOpWithPromotion: @@ -52,4 +56,29 @@ def __call__(self, *args, **kwargs): +P = ParamSpec('P') +def convert_dtype(use_default_dtype: bool = True): + """Converts `dtype` kwarg of function from torch to JAX. + + Args: + use_default_dtype: Whether to use torch default dtype if none is provided. + + Returns: + A decorator that wraps a JAX implementation of a torch function. + """ + + def decorator(func: types.TorchCallable): + + @functools.wraps(func) + def wrapper(*args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs): + if not dtype and use_default_dtype: + dtype = torch.get_default_dtype() + jax_dtype = tensor.t2j_dtype(dtype) + + return func(*args, dtype=jax_dtype, **kwargs) + + return wrapper + return decorator