training UX: automatic generation of make_train_step #8495

Open · wants to merge 1 commit into base: master
111 changes: 26 additions & 85 deletions experimental/torch_xla2/examples/basic_training_jax.py
@@ -3,6 +3,8 @@
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
"""

import functools
from torch_xla2 import train, interop
import torch
from torch.utils import _pytree as pytree
import torchvision
@@ -17,6 +19,8 @@
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

env = torch_xla2.enable_globally()


transform = transforms.Compose(
[transforms.ToTensor(),
@@ -38,29 +42,7 @@
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
if one_channel:
img = img.mean(dim=0)
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
if one_channel:
plt.imshow(npimg, cmap="Greys")
else:
plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print(' '.join(classes[labels[j]] for j in range(4)))


import torch.nn as nn
import torch.nn.functional as F

@@ -83,62 +65,55 @@ def forward(self, x):
model = GarmentClassifier()
loss_fn = torch.nn.CrossEntropyLoss()

jax_weights, jax_func = torch_xla2.extract_jax(model)
jax_func = jax.jit(jax_func, inline=True)
jax_optimizer = optax.adam(0.01)
opt_state = jax_optimizer.init(jax_weights)

model.to('jax') # move the model to jax device
model_jittable = interop.JittableModule(model)
weights = model_jittable.params # these are trainable parameters
buffers = model_jittable.buffers # these are non-trainable parameters

def jax_loss(weights, data, label):
pred = jax_func(weights, data)
loss = torch_xla2.interop.call_torch(loss_fn, pred, label)
return loss
opt_state = interop.call_jax(jax_optimizer.init, weights)
model_fn = functools.partial(model_jittable.functional_call, 'forward')

grad_fn = jax.jit(jax.value_and_grad(jax_loss))
train_step = train.make_train_step(model_fn, loss_fn, jax_optimizer)

train_step = interop.jax_jit(train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})

# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10)
dummy_inputs = torch.rand(4, 28, 28).to('jax')
dummy_outputs = torch.rand(4, 10).to('jax')
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7])

print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))

dummy_labels = torch.tensor([1, 5, 3, 7]).to('jax')

def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
# test train_step

def train_one_epoch(weights, buffers, opt_state, epoch_index, tb_writer):
running_loss = 0.
last_loss = 0.

# Here, we use enumerate(training_loader) instead of
# iter(training_loader) so that we can track the batch
# index and do some intra-epoch reporting
for i, data in enumerate(training_loader):
# Every data instance is an input + label pair
# NEW: move the input batch to the XLA device
data = pytree.tree_map_only(torch.Tensor,
torch_xla2.tensor.t2j, data)
inputs, labels = data

val, grads = grad_fn(jax_weights, (inputs, ), labels)
updates, opt_state = jax_optimizer.update(grads, opt_state)
jax_weights = optax.apply_updates(jax_weights, updates)
inputs = inputs.to('jax')
labels = labels.to('jax')

loss, weights, opt_state = train_step(
weights, buffers, opt_state, inputs, labels)

# Gather data and report
running_loss += val.item()
running_loss += loss.item()
if i % 1000 == 999:
last_loss = running_loss / 1000 # loss per batch
print(' batch {} loss: {}'.format(i + 1, last_loss))
tb_x = epoch_index * len(training_loader) + i + 1
tb_writer.add_scalar('Loss/train', last_loss, tb_x)
running_loss = 0.

return last_loss, opt_state
return last_loss, weights, opt_state



@@ -152,39 +127,5 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
for epoch in range(EPOCHS):
print('EPOCH {}:'.format(epoch_number + 1))

# Make sure gradient tracking is on, and do a pass over the data
model.train(True)

avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer)

running_vloss = 0.0
# Set the model to evaluation mode, disabling dropout and using population
# statistics for batch normalization.
model.eval()

# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
for i, vdata in enumerate(validation_loader):

vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata)
voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward
vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels)
running_vloss += vloss

avg_vloss = running_vloss / (i + 1)
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

# Log the running loss averaged per batch
# for both training and validation
writer.add_scalars('Training vs. Validation Loss',
{ 'Training' : np.asarray(avg_loss), 'Validation' : np.asarray(avg_vloss) },
epoch_number + 1)
writer.flush()

# Track best performance, and save the model's state
if avg_vloss < best_vloss:
best_vloss = avg_vloss
model_path = 'model_{}_{}'.format(timestamp, epoch_number)
torch.save(model.state_dict(), model_path)

epoch_number += 1
avg_loss, weights, opt_state = train_one_epoch(weights, buffers, opt_state, epoch_number, writer)
print(avg_loss)
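
Taken together, this file now replaces the hand-written jax_loss / grad_fn plumbing with the new train.make_train_step helper. A condensed, runnable sketch of the resulting flow is below; only the torch_xla2 calls are taken from this diff, while the Linear model, shapes, and learning rate are illustrative placeholders, not part of the PR.

import functools
import optax
import torch
import torch_xla2
from torch_xla2 import interop, train

env = torch_xla2.enable_globally()

# Stand-in model; the example above uses GarmentClassifier.
model = torch.nn.Linear(28 * 28, 10)
model.to('jax')  # move parameters to the jax device
jittable = interop.JittableModule(model)
weights, buffers = jittable.params, jittable.buffers

optimizer = optax.adam(0.01)
opt_state = interop.call_jax(optimizer.init, weights)
model_fn = functools.partial(jittable.functional_call, 'forward')

step = train.make_train_step(model_fn, torch.nn.CrossEntropyLoss(), optimizer)
step = interop.jax_jit(step, kwargs_for_jax_jit={'donate_argnums': (0, 2)})

# One optimizer update; inputs and labels must already live on the 'jax' device.
inputs = torch.rand(4, 28 * 28).to('jax')
labels = torch.tensor([1, 5, 3, 7]).to('jax')
loss, weights, opt_state = step(weights, buffers, opt_state, inputs, labels)
print(loss.item())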
5 changes: 5 additions & 0 deletions experimental/torch_xla2/torch_xla2/interop.py
@@ -198,3 +198,8 @@ def jax_jit(torch_function, kwargs_for_jax_jit=None):
def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
  return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
                      kwargs_for_jax=kwargs_for_jax_shard_map)


def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
  return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad,
                      kwargs_for_jax=kwargs_for_value_and_grad)
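
The new jax_value_and_grad wrapper follows the same pattern as jax_jit and jax_shard_map: it lifts a torch-level function through wrap_jax_jit so that jax.value_and_grad can differentiate it. A minimal sketch of using it on its own is below; the toy least-squares loss and shapes are illustrative assumptions, and, per jax.value_and_grad's default, gradients are taken with respect to the first argument.

import torch
import torch_xla2
from torch_xla2 import interop

torch_xla2.enable_globally()

def loss(w, x, y):  # torch tensors living on the 'jax' device
  return ((x @ w - y) ** 2).mean()

value_and_grad = interop.jax_value_and_grad(loss)  # differentiates w.r.t. w

w = torch.zeros(3, 1).to('jax')
x = torch.rand(8, 3).to('jax')
y = torch.rand(8, 1).to('jax')
val, grad_w = value_and_grad(w, x, y)
print(val.item(), grad_w.shape)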
64 changes: 64 additions & 0 deletions experimental/torch_xla2/torch_xla2/train.py
@@ -0,0 +1,64 @@
import functools
import torch
import jax
from jax.sharding import PartitionSpec as P  # needed for the sharding annotations below
import torch_xla2
from torch_xla2 import interop
from torch_xla2.interop import torch_view, jax_view
import optax


remat = torch_view(jax.remat)
mark_sharding = torch_view(jax.lax.with_sharding_constraint)


def make_train_step(model_fn,
                    loss_fn, optax_optimizer,
                    remat_policy=None,
                    mark_fsdp_sharding_axis=None):
  """Make a function that performs one train step given a model and a loss.

  model_fn: a function representing the model's forward pass,
    i.e. with signature Callable[weights, buffers, args] -> result, where
    weights is a pytree of trainable parameters,
    buffers is a pytree of non-trainable parameters / constants,
    args is the input data loaded from the data set, and
    result is the return value of the model.
  loss_fn: a function to compute the loss,
    i.e. with signature Callable[result, label] -> loss, where
    result is what model_fn returned and
    label is loaded from the dataloader.
  optax_optimizer: an optimizer from the optax library, for example optax.adam.
  remat_policy: one of jax.ad_checkpoint.checkpoint_policies; specifies how
    to do gradient checkpointing. If None, no checkpointing is done.
  mark_fsdp_sharding_axis: str. The name of the axis to shard along for
    FSDP. It must be an axis that exists in the current mesh. If None,
    no sharding is specified (i.e. for single device).

  Returns a function step(weights, buffers, opt_state, args, label) that runs
  one optimizer update and returns (loss, new_weights, new_opt_state).
  """
  env = torch_xla2.default_env()

  @functools.partial(
      remat,
      policy=remat_policy)
  def loss(weights, buffers, args, label):  # inputs are XLATensor
    with env, jax.named_scope('compute_loss'):
      if mark_fsdp_sharding_axis is not None:
        args = (mark_sharding(args[0], P(mark_fsdp_sharding_axis)), *args[1:])
      res = model_fn(weights, buffers, args)
      if mark_fsdp_sharding_axis is not None:
        res = mark_sharding(res, P(mark_fsdp_sharding_axis))
      l = loss_fn(res, label)
      return l

  grad_fn = interop.jax_value_and_grad(loss)

  def step(weights, buffers, opt_state, args, label):  # inputs are jax arrays
    with jax.named_scope('compute_gradient'):
      loss, gradient = grad_fn(weights, buffers, args, label)

    with jax.named_scope("optimizer_updates"):
      updates, opt_state = interop.call_jax(
          optax_optimizer.update,
          gradient, opt_state, weights)
      weights = interop.call_jax(optax.apply_updates, weights, updates)
    return loss, weights, opt_state

  return step
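
The example file in this PR only exercises make_train_step with the defaults. The remat_policy argument plugs straight into jax.remat; the sketch below shows one way to pass a checkpoint policy through (the model, shapes, and the specific policy are illustrative assumptions, and mark_fsdp_sharding_axis is omitted because it additionally needs an active mesh containing that axis).

import functools
import jax
import optax
import torch
import torch_xla2
from torch_xla2 import interop, train

torch_xla2.enable_globally()

model = torch.nn.Linear(16, 4).to('jax')
jittable = interop.JittableModule(model)
model_fn = functools.partial(jittable.functional_call, 'forward')
optimizer = optax.sgd(0.1)
opt_state = interop.call_jax(optimizer.init, jittable.params)

# Recompute (rather than save) intermediates during the backward pass.
step = train.make_train_step(
    model_fn,
    torch.nn.MSELoss(),
    optimizer,
    remat_policy=jax.checkpoint_policies.nothing_saveable)
step = interop.jax_jit(step)

x = torch.rand(8, 16).to('jax')
y = torch.rand(8, 4).to('jax')
loss, weights, opt_state = step(jittable.params, jittable.buffers, opt_state, x, y)
print(loss.item())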