Add example to support pytorch lightning; misc bug fixes (#7054)
qihqi authored May 14, 2024
1 parent 2b28ae2 commit b64d8a2
Showing 5 changed files with 126 additions and 44 deletions.
77 changes: 77 additions & 0 deletions experimental/torch_xla2/examples/lightning_training.py
@@ -0,0 +1,77 @@
import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
import lightning as L

encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

class LitAutoEncoder(L.LightningModule):
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder, self.decoder = encoder, decoder

  def training_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    self.log("train_loss", loss)
    return loss

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=1e-3)

dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())

# Lightning will automatically use all available GPUs!
trainer = L.Trainer()
# trainer.fit(LitAutoEncoder(encoder, decoder), data.DataLoader(dataset, batch_size=64))

# ==== above is the lightning example from
# https://lightning.ai/pytorch-lightning

import torch_xla2
from torch_xla2.interop import jax_view, torch_view
import jax
import optax

class JaxTrainer:

  def __init__(self):
    pass

  def torch_opt_to_jax_opt(self, torch_opt):
    # TODO: Can convert optimizer instead of using a jax one
    return optax.adam(0.001)

  def fit(self, lightning_mod, data_loader):

    xla_env = torch_xla2.default_env()

    def lightning_mod_loss(
        weights: jax.Array, data: jax.Array, batch_id):
      """returns loss"""
      weights, data = torch_view((weights, data))
      lightning_mod.load_state_dict(weights, assign=True)
      with xla_env:
        loss = lightning_mod.training_step(data, batch_id)
      return jax_view(loss)

    jax_weights = jax_view(xla_env.to_xla(lightning_mod.state_dict()))
    jax_optimizer = self.torch_opt_to_jax_opt(
        lightning_mod.configure_optimizers())
    opt_state = jax_optimizer.init(jax_weights)
    grad_fn = jax.jit(jax.value_and_grad(lightning_mod_loss))

    for bid in range(3):
      for item in data_loader:
        xla_data = jax_view(xla_env.to_xla(item))
        loss, grads = grad_fn(jax_weights, xla_data, bid)
        updates, opt_state = jax_optimizer.update(grads, opt_state)
        jax_weights = optax.apply_updates(jax_weights, updates)
        print('current_loss', loss)


print('-----------------')
trainer_jax = JaxTrainer()
trainer_jax.fit(LitAutoEncoder(encoder, decoder), data.DataLoader(dataset, batch_size=64))
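
The TODO in torch_opt_to_jax_opt above hints at reading the hyperparameters off the torch optimizer instead of hard-coding optax.adam(0.001). A minimal sketch of what that could look like (not part of the commit; it assumes a single param group and only covers Adam and SGD):

import optax
import torch

def torch_opt_to_jax_opt(torch_opt):
  # Pull hyperparameters from the first (and assumed only) param group.
  group = torch_opt.param_groups[0]
  if isinstance(torch_opt, torch.optim.Adam):
    b1, b2 = group['betas']
    return optax.adam(group['lr'], b1=b1, b2=b2, eps=group['eps'])
  if isinstance(torch_opt, torch.optim.SGD):
    return optax.sgd(group['lr'], momentum=group['momentum'] or None)
  raise NotImplementedError(f'No optax mapping for {type(torch_opt)}')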
14 changes: 9 additions & 5 deletions experimental/torch_xla2/torch_xla2/interop.py
@@ -10,8 +10,7 @@



-
-def torch_view(t: JaxValue) -> TorchValue:
+def _torch_view(t: JaxValue) -> TorchValue:
   # t is an object from jax land
   # view it as-if it's a torch land object
   if isinstance(t, jax.Array):
@@ -24,8 +23,10 @@ def torch_view(t: JaxValue) -> TorchValue:
   # regular types are not changed
   return t

+torch_view = functools.partial(pytree.tree_map, _torch_view)
+

-def jax_view(t: TorchValue) -> JaxValue:
+def _jax_view(t: TorchValue) -> JaxValue:
   # t is an object from torch land
   # view it as-if it's a jax land object
   if isinstance(t, torch.Tensor):
@@ -40,17 +41,19 @@ def jax_view(t: TorchValue) -> JaxValue:
   # regular types are not changed
   return t

+jax_view = functools.partial(pytree.tree_map, _jax_view)
+

 def call_jax(jax_func: JaxCallable,
              *args: TorchValue,
              **kwargs: TorchValue) -> TorchValue:
-  args, kwargs = pytree.tree_map(jax_view, (args, kwargs))
+  args, kwargs = jax_view((args, kwargs))
   res: JaxValue = jax_func(*args, **kwargs)
   return torch_view(res)


 def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue:
-  args, kwargs = pytree.tree_map(torch_view, (args, kwargs))
+  args, kwargs = torch_view((args, kwargs))
   with torch_xla2.default_env():
     res: TorchValue = torch_func(*args, **kwargs)
     return jax_view(res)
@@ -63,3 +66,4 @@ def jax_jit(torch_function, kwargs_for_jax_jit=None):
   jax_func = jax_view(torch_function)
   jitted = jax.jit(jax_func, **kwargs_for_jax_jit)
   return torch_view(jitted)
+
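
The net effect of this change: torch_view and jax_view are now pytree.tree_map partials built on the per-leaf _torch_view/_jax_view, so they accept arbitrarily nested containers directly, which is what lets call_jax and call_torch drop their own tree_map calls. A quick hedged sketch of the intended usage (not from the commit):

import jax
import jax.numpy as jnp
from torch_xla2.interop import jax_view, torch_view

nested = {'w': jnp.ones((2, 2)), 'layers': [jnp.zeros(3), 1.0]}
as_torch = torch_view(nested)   # jax.Arrays become torch-land tensors; the plain 1.0 passes through
back = jax_view(as_torch)       # and back again, preserving the dict/list structure
assert isinstance(back['w'], jax.Array)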
4 changes: 2 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -1827,6 +1827,7 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0):


 @op(torch.ops.aten.randn, needs_env=True)
+@op_base.convert_dtype()
 def _randn(
     *size,
     generator=None,
@@ -1844,12 +1845,12 @@ def _randn(
   key = env.get_and_rotate_prng_key()
   res = jax.random.normal(key, shape)
   if dtype is not None:
-    dtype = tensor.t2j_dtype(dtype)
     res = res.astype(dtype)
   return res


 @op(torch.ops.aten.rand, needs_env=True)
+@op_base.convert_dtype()
 def _rand(
     *size,
     generator=None,
@@ -1867,7 +1868,6 @@
   key = env.get_and_rotate_prng_key()
   res = jax.random.uniform(key, shape)
   if dtype is not None:
-    dtype = tensor.t2j_dtype(dtype)
     res = res.astype(dtype)
   return res

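
With @op_base.convert_dtype() applied, _randn and _rand now receive an already-converted JAX dtype (falling back to torch.get_default_dtype() when the caller passes none), so the inline tensor.t2j_dtype(dtype) calls are no longer needed. A rough sketch of the observable behavior, assuming the default env routes aten.randn/aten.rand to these implementations (not part of the commit):

import torch
import torch_xla2

with torch_xla2.default_env():
  a = torch.randn(2, 3, dtype=torch.float16)  # decorator converts torch.float16 -> jnp.float16
  b = torch.rand(2, 3)                        # no dtype -> torch default -> jnp.float32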
44 changes: 8 additions & 36 deletions experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -1,49 +1,21 @@
 """Tensor constructor overrides"""
 import functools
-from typing import Callable, Optional, ParamSpec, Sequence
+from typing import Optional, Sequence

 import jax
 import torch
 import jax.numpy as jnp
 from torch_xla2 import tensor
 from torch_xla2.ops.ops_registry import register_torch_function_op
+from torch_xla2.ops import op_base


 def register_function(torch_func, **kwargs):
   return functools.partial(register_torch_function_op, torch_func, **kwargs)


-P = ParamSpec('P')
-
-
-def convert_dtype(use_default_dtype: bool = True):
-  """Converts `dtype` kwarg of function from torch to JAX.
-  Args:
-    use_default_dtype: Whether to use torch default dtype if none is provided.
-  Returns:
-    A decorator that wraps a JAX implementation of a torch function.
-  """
-
-  def decorator(func: Callable[P, torch.Tensor]):
-
-    @functools.wraps(func)
-    def wrapper(*args: P.args,
-                dtype: Optional[torch.dtype] = None,
-                **kwargs: P.kwargs):
-      if not dtype and use_default_dtype:
-        dtype = torch.get_default_dtype()
-      jax_dtype = tensor.t2j_dtype(dtype)
-
-      return func(*args, dtype=jax_dtype, **kwargs)
-
-    return wrapper
-
-  return decorator
-
-
 @register_function(torch.tensor)
-@convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
+@op_base.convert_dtype(use_default_dtype=False)  # Attempt to infer type from elements
 def _tensor(data, *, dtype=None, **kwargs):
   python_types_to_torch_types = {
       bool: jnp.bool,
@@ -61,25 +33,25 @@ def _tensor(data, *, dtype=None, **kwargs):


 @register_function(torch.ones)
-@convert_dtype()
+@op_base.convert_dtype()
 def _ones(*size: int, dtype=None, **kwargs):
   return jnp.ones(size, dtype)


 @register_function(torch.zeros)
-@convert_dtype()
+@op_base.convert_dtype()
 def _zeros(*size: int, dtype=None, **kwargs):
   return jnp.zeros(size, dtype)


 @register_function(torch.eye)
-@convert_dtype()
+@op_base.convert_dtype()
 def _eye(n: int, m: Optional[int] = None, *, dtype=None, **kwargs):
   return jnp.eye(n, m, dtype=dtype)


 @register_function(torch.full)
-@convert_dtype()
+@op_base.convert_dtype()
 def _full(size: Sequence[int], fill_value, *, dtype=None, **kwargs):
   # TODO: handle torch.Size
   return jnp.full(size, fill_value, dtype=dtype)
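
The constructors themselves are unchanged here; only the decorator now lives in op_base. As a hedged illustration (not from the commit, and assuming the default env's function mode dispatches these torch functions to the overrides above):

import torch
import torch_xla2

with torch_xla2.default_env():
  e = torch.eye(3)                              # handled by _eye; dtype defaults via convert_dtype
  f = torch.full((2, 2), 7, dtype=torch.int32)  # dtype converted to jnp.int32 before jnp.full runs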
31 changes: 30 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/op_base.py
@@ -1,5 +1,9 @@
 import functools
+import torch
-from torch_xla2 import interop
+from torch_xla2 import interop, tensor
+from torch_xla2 import types
+
+from typing import Callable, Optional, ParamSpec, Sequence


 class BinaryOpWithPromotion:
@@ -52,4 +56,29 @@ def __call__(self, *args, **kwargs):



+P = ParamSpec('P')
+def convert_dtype(use_default_dtype: bool = True):
+  """Converts `dtype` kwarg of function from torch to JAX.
+  Args:
+    use_default_dtype: Whether to use torch default dtype if none is provided.
+  Returns:
+    A decorator that wraps a JAX implementation of a torch function.
+  """
+
+  def decorator(func: types.TorchCallable):
+
+    @functools.wraps(func)
+    def wrapper(*args: P.args,
+                dtype: Optional[torch.dtype] = None,
+                **kwargs: P.kwargs):
+      if not dtype and use_default_dtype:
+        dtype = torch.get_default_dtype()
+      jax_dtype = tensor.t2j_dtype(dtype)
+
+      return func(*args, dtype=jax_dtype, **kwargs)
+
+    return wrapper
+
+  return decorator
+
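
This is the decorator moved from jtorch.py (with the func annotation switched to types.TorchCallable) so that jaten.py can use it too. A small usage sketch with a made-up op, not from the codebase: the wrapped implementation only ever sees a JAX dtype.

import jax.numpy as jnp
import torch
from torch_xla2.ops import op_base

@op_base.convert_dtype()            # dtype=None falls back to torch.get_default_dtype()
def _arange(n, *, dtype=None):
  return jnp.arange(n, dtype=dtype)

x = _arange(4, dtype=torch.float16)  # the decorator hands _arange jnp.float16
y = _arange(4)                       # None -> torch default -> jnp.float32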