From 781ee938fad994f91cd973fb7dcde17d2655d6b2 Mon Sep 17 00:00:00 2001 From: qihqi Date: Tue, 16 Apr 2024 18:24:39 -0700 Subject: [PATCH] Add examples for training (#6929) --- experimental/torch_xla2/README.md | 96 ++++++++- experimental/torch_xla2/examples/README.md | 115 ++++++++++ .../torch_xla2/examples/_diffusion.py | 112 ++++++++++ .../torch_xla2/examples/basic_training.py | 197 ++++++++++++++++++ .../torch_xla2/examples/basic_training_jax.py | 196 +++++++++++++++++ .../torch_xla2/examples/eager_mode.py | 42 ++++ .../torch_xla2/examples/requirements.txt | 3 + .../torch_xla2/test/llama/test_llama.py | 1 - experimental/torch_xla2/torch_xla2/_ops.py | 59 ++++++ experimental/torch_xla2/torch_xla2/tensor.py | 1 - 10 files changed, 819 insertions(+), 3 deletions(-) create mode 100644 experimental/torch_xla2/examples/README.md create mode 100644 experimental/torch_xla2/examples/_diffusion.py create mode 100644 experimental/torch_xla2/examples/basic_training.py create mode 100644 experimental/torch_xla2/examples/basic_training_jax.py create mode 100644 experimental/torch_xla2/examples/eager_mode.py create mode 100644 experimental/torch_xla2/examples/requirements.txt diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index 3b3cf5cc23c..f30be7ff1da 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -71,4 +71,98 @@ pip install -e . ```bash pip install -r test_requirements.txt pytest test -``` \ No newline at end of file +``` + + +## Run a model + +Now let's execute a model under torch_xla2. We'll start with a simple 2-layer model +it can be in theory any instance of `torch.nn.Module`. 
+ +```python + +import torch_xla2 +from torch import nn + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + +m = MyModel() + +# Execute this model using torch +inputs = (torch.randn(3, 3, 28, 28), ) +print(m(*inputs)) +``` + +This model `m` contains 2 parts: the weights that are stored inside of the model +and its submodules (`nn.Linear`). + +To execute this model with `torch_xla2`, we need to move the tensors involved in compute +to `XLA` devices. This can be accomplished with `torch_xla2.tensor.move_to_device`. + +We need to move both the weights and the input to XLA devices: + +```python +from torch.utils import _pytree as pytree +from torch_xla2.tensor import move_to_device + +inputs = move_to_device(inputs) +new_state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, m.state_dict()) +m.load_state_dict(new_state_dict, assign=True) + +res = m(*inputs) + +print(type(res)) # outputs XLATensor2 +``` + +### Executing with jax.jit + +The above script will execute the model using eager mode Jax as backend. This +does allow executing torch models on TPU, but is often slower than what we can +achieve with `jax.jit`. + +`jax.jit` is a function that turns a Jax function (i.e. a function that takes jax arrays +and returns jax arrays) into the same function, but faster. + +We have made the `jax_jit` decorator that accomplishes the same for functions +that take and return `torch.Tensor`. 
To use this, the first step is to create +a functional version of this model: this means the parameters should be passed in +as input instead of being attributes on the class: + + +```python + +def model_func(param, inputs): + return torch.func.functional_call(m, param, inputs) + +``` +Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html) +from PyTorch to replace the model +weights with `param`, then call the model. This is roughly equivalent to: + +```python +def model_func(param, inputs): + m.load_state_dict(param) + return m(*inputs) +``` + +Now, we can apply `jax_jit`: + +```python +from torch_xla2.extra import jax_jit +model_func_jitted = jax_jit(model_func) +print(model_func_jitted(new_state_dict, inputs)) +``` + + diff --git a/experimental/torch_xla2/examples/README.md b/experimental/torch_xla2/examples/README.md new file mode 100644 index 00000000000..0e22d28c531 --- /dev/null +++ b/experimental/torch_xla2/examples/README.md @@ -0,0 +1,115 @@ +## Intro + +This readme will have a subsection for every example *.py file. + +Please follow the instructions in [README.md](../README.md) to install torch_xla2, +then install requirements for all of the examples with + +```bash +pip install -r requirements.txt +``` + + + +## basic_training.py + +This file was constructed by first copying and pasting code fragments from this pytorch training tutorial: +https://pytorch.org/tutorials/beginner/introyt/trainingyt.html + +Then adding a few lines of code that serve the purpose of moving `torch.Tensor` into +`XLA devices`. + +Example: + +```python +state_dict = pytree.tree_map_only(torch.Tensor, + torch_xla2.tensor.move_to_device, state_dict) +``` + +This fragment moves the state_dict to XLA devices; then the state_dict is passed +back to the model via `load_state_dict`. + +Then, you can train the model. This shows the minimum needed to train a model on XLA +devices. 
The perf is not as good because we didn't use `jax.jit`; this is intentional +as it is meant to showcase the minimum code change. + +Example run: +```bash +(xla2) hanq-macbookpro:examples hanq$ python basic_training.py +Training set has 60000 instances +Validation set has 10000 instances +Bag Dress Sneaker T-shirt/top +tensor([[0.8820, 0.3807, 0.3010, 0.9266, 0.7253, 0.9265, 0.0688, 0.4567, 0.7035, + 0.2279], + [0.3253, 0.1558, 0.1274, 0.2776, 0.2590, 0.4169, 0.1881, 0.7423, 0.4561, + 0.5985], + [0.5067, 0.4514, 0.9758, 0.6088, 0.7438, 0.6811, 0.9609, 0.3572, 0.4504, + 0.8738], + [0.1850, 0.1217, 0.8551, 0.2120, 0.9902, 0.7623, 0.1658, 0.6980, 0.3086, + 0.5709]]) +tensor([1, 5, 3, 7]) +Total loss for this batch: 2.325265645980835 +EPOCH 1: + batch 1000 loss: 1.041275198560208 + batch 2000 loss: 0.6450189483696595 + batch 3000 loss: 0.5793989677671343 + batch 4000 loss: 0.5170258888280951 + batch 5000 loss: 0.4920090722264722 + batch 6000 loss: 0.48910293977567926 + batch 7000 loss: 0.48058812761632724 + batch 8000 loss: 0.47159107415075413 + batch 9000 loss: 0.4712311488997657 + batch 10000 loss: 0.4675815168160479 + batch 11000 loss: 0.43210567891132085 + batch 12000 loss: 0.445208148030797 + batch 13000 loss: 0.4119230824254337 + batch 14000 loss: 0.4190662656680215 + batch 15000 loss: 0.4094535468676477 +LOSS train 0.4094535468676477 valid XLA +``` + +## basic_training_jax.py + +This file was constructed by first copying and pasting code fragments from this pytorch training tutorial: +https://pytorch.org/tutorials/beginner/introyt/trainingyt.html + +Then replacing the torch optimizer with an `optax` optimizer, and using `jax.grad` for +gradients instead of `torch.Tensor.backward()`. + +Then, you can train the model using the Jax ecosystem's training loop. This is meant to +showcase how easy it is to integrate with Jax. 
+ +Example run: +```bash +(xla2) hanq-macbookpro:examples hanq$ python basic_training_jax.py +Training set has 60000 instances +Validation set has 10000 instances +Pullover Ankle Boot Pullover Ankle Boot +tensor([[0.5279, 0.8340, 0.3131, 0.8608, 0.3668, 0.6192, 0.7453, 0.3261, 0.8872, + 0.1854], + [0.7414, 0.8309, 0.8127, 0.8866, 0.2475, 0.2664, 0.0327, 0.6918, 0.6010, + 0.2766], + [0.3304, 0.9135, 0.2762, 0.6737, 0.0480, 0.6150, 0.5610, 0.5804, 0.9607, + 0.6450], + [0.9464, 0.9439, 0.3122, 0.1814, 0.1194, 0.5012, 0.2058, 0.1170, 0.7377, + 0.7453]]) +tensor([1, 5, 3, 7]) +Total loss for this batch: 2.4054245948791504 +EPOCH 1: + batch 1000 loss: 1.0705260595591972 + batch 2000 loss: 1.0997755021179327 + batch 3000 loss: 1.0186579653513108 + batch 4000 loss: 0.9090727646966116 + batch 5000 loss: 0.8309370622411024 + batch 6000 loss: 0.8702225417760783 + batch 7000 loss: 0.8750176187023462 + batch 8000 loss: 0.9652624803795453 + batch 9000 loss: 0.8688667197711766 + batch 10000 loss: 0.8021814124770199 + batch 11000 loss: 0.8000540231048071 + batch 12000 loss: 0.9150884484921057 + batch 13000 loss: 0.819690621060171 + batch 14000 loss: 0.8569030471532278 + batch 15000 loss: 0.8740896808278603 +LOSS train 0.8740896808278603 valid 2.3132264614105225 +``` \ No newline at end of file diff --git a/experimental/torch_xla2/examples/_diffusion.py b/experimental/torch_xla2/examples/_diffusion.py new file mode 100644 index 00000000000..5eae15edf25 --- /dev/null +++ b/experimental/torch_xla2/examples/_diffusion.py @@ -0,0 +1,112 @@ +import functools + +import torch +from time import time +from diffusers import DiffusionPipeline +from torch.utils import _pytree as pytree + + +import torch_xla2 +import torch_xla2.functions +from torch_xla2.extra import torch_view, jax_view + +import jax +import torch.func + + +class CompiledModule: + + def __init__(self, model): + weights = model.state_dict() + weights.update(model.named_parameters()) + self._weights = 
pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.move_to_device, weights) + self._model = model + + self._func_jitted_torch = None #torch_view(func_mod_jitted) + + + def _maybe_move_tensor(self, tensor): + if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torch_xla2.tensor.XLATensor2): + return torch_xla2.tensor.move_to_device(tensor) + return tensor + + def _make_jitted(self, args, kwargs): + static = [] + for i, a in enumerate(args): + if not isinstance(a, torch.Tensor): + static.append(i + 1) # weight is 0 + static_argnames = [] + for k, v in kwargs.items(): + if not isinstance(v, torch.Tensor): + static_argnames.append(k) + + def f(weights, *args, **kwargs): + weights, args, kwargs = torch_xla2.tensor.wrap((weights, args, kwargs)) + with torch_xla2.functions.XLAFunctionMode(), torch_xla2.tensor.XLADispatchMode(): + res = torch.func.functional_call(self._model, weights, args, kwargs) + if isinstance(res, tuple) and len(res) == 1: + res = res[0] + return torch_xla2.tensor.unwrap(res) + + fjit = jax.jit(f, static_argnames=tuple(static_argnames)) + return torch_view(fjit) + + + def forward(self, *args, **kwargs): + (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs)) + if self._func_jitted_torch is None: + self._func_jitted_torch = self._make_jitted(args, kwargs) + return self._func_jitted_torch( + self._weights, + *args, + **kwargs + ) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def __getattr__(self, key): + return getattr(self._model, key) + + +def compile_pipe(pipe): + pipe.text_encoder = CompiledModule(pipe.text_encoder) + pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2) + pipe.unet = CompiledModule(pipe.unet) + pipe.vae = CompiledModule(pipe.vae) + + +def main(): + pipe = DiffusionPipeline.from_pretrained( + # "stabilityai/stable-diffusion-xl-base-0.9", + "stabilityai/stable-diffusion-xl-base-1.0", + use_safetensors=True, + + ) + compile_pipe(pipe) + + global_bs = 10 + 
inference_steps = 20 + resol = 1024 + prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs + print(f'global batch size {global_bs}', + f'inference steps {inference_steps}', + f'Image resolution {resol}', + flush=True + ) + + iters = 5 + for i in range(iters): + prompt = prompts + # print('per device prompts len',len(prompt)) + # prompt = prompts[rank] + start = time() + image = pipe(prompt, + num_inference_steps=inference_steps, + height=resol, + width=resol).images[0] + print(f'Step {i} inference time {time()-start} sec', flush=True) + + +if __name__ == '__main__': + main() diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py new file mode 100644 index 00000000000..5d3f5a734c5 --- /dev/null +++ b/experimental/torch_xla2/examples/basic_training.py @@ -0,0 +1,197 @@ +""" +This is the script from this tutorial: +https://pytorch.org/tutorials/beginner/introyt/trainingyt.html + +Then, it's modified to make the training loop using Jax's grad +and optimizer +""" + +import torch +from torch.utils import _pytree as pytree +import torchvision +import torchvision.transforms as transforms +import torch_xla2 + +# PyTorch TensorBoard support +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime + + +transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,))]) + +# Create datasets for training & validation, download if necessary +training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True) +validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True) + +# Create data loaders for our datasets; shuffle for training, not for validation +training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True) +validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False) + +# Class labels +classes 
= ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', + 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot') + +# Report split sizes +print('Training set has {} instances'.format(len(training_set))) +print('Validation set has {} instances'.format(len(validation_set))) + +import matplotlib.pyplot as plt +import numpy as np + +# Helper function for inline image display +def matplotlib_imshow(img, one_channel=False): + if one_channel: + img = img.mean(dim=0) + img = img / 2 + 0.5 # unnormalize + npimg = img.numpy() + if one_channel: + plt.imshow(npimg, cmap="Greys") + else: + plt.imshow(np.transpose(npimg, (1, 2, 0))) + +dataiter = iter(training_loader) +images, labels = next(dataiter) + +# Create a grid from the images and show them +img_grid = torchvision.utils.make_grid(images) +matplotlib_imshow(img_grid, one_channel=True) +print(' '.join(classes[labels[j]] for j in range(4))) + + +import torch.nn as nn +import torch.nn.functional as F + +# PyTorch models inherit from torch.nn.Module +class GarmentClassifier(nn.Module): + def __init__(self): + super(GarmentClassifier, self).__init__() + self.fc1 = nn.Linear(28 * 28, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +model = GarmentClassifier() + +loss_fn = torch.nn.CrossEntropyLoss() + +# NB: Loss functions expect data in batches, so we're creating batches of 4 +# Represents the model's confidence in each of the 10 classes for a given input +dummy_outputs = torch.rand(4, 10) +# Represents the correct class among the 10 being tested +dummy_labels = torch.tensor([1, 5, 3, 7]) + +print(dummy_outputs) +print(dummy_labels) + +loss = loss_fn(dummy_outputs, dummy_labels) +print('Total loss for this batch: {}'.format(loss.item())) + +# Optimizers specified in the torch.optim package + +# NEW: Move model to XLA device +state_dict = model.state_dict() +state_dict = 
pytree.tree_map_only(torch.Tensor, + torch_xla2.tensor.move_to_device, state_dict) +model.load_state_dict(state_dict, strict=False, assign=True) + +optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + +def train_one_epoch(epoch_index, tb_writer): + running_loss = 0. + last_loss = 0. + + # Here, we use enumerate(training_loader) instead of + # iter(training_loader) so that we can track the batch + # index and do some intra-epoch reporting + for i, data in enumerate(training_loader): + # Every data instance is an input + label pair + # NEW: Move model to XLA device + data = pytree.tree_map_only(torch.Tensor, + torch_xla2.tensor.move_to_device, data) + inputs, labels = data + + # Zero your gradients for every batch! + optimizer.zero_grad() + + # Make predictions for this batch + outputs = model(inputs) + + # Compute the loss and its gradients + loss = loss_fn(outputs, labels) + loss.backward() + + # Adjust learning weights + optimizer.step() + + # Gather data and report + running_loss += loss.item() + if i % 1000 == 999: + last_loss = running_loss / 1000 # loss per batch + print(' batch {} loss: {}'.format(i + 1, last_loss)) + tb_x = epoch_index * len(training_loader) + i + 1 + tb_writer.add_scalar('Loss/train', last_loss, tb_x) + running_loss = 0. + + return last_loss + + + +# Initializing in a separate cell so we can easily add more epochs to the same run +timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') +writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) +epoch_number = 0 +EPOCHS = 2 +best_vloss = 1_000_000. + +for epoch in range(EPOCHS): + print('EPOCH {}:'.format(epoch_number + 1)) + + # Make sure gradient tracking is on, and do a pass over the data + model.train(True) + + + avg_loss = train_one_epoch(epoch_number, writer) + + running_vloss = 0.0 + # Set the model to evaluation mode, disabling dropout and using population + # statistics for batch normalization. 
+ model.eval() + + # Disable gradient computation and reduce memory consumption. + with torch.no_grad(): + for i, vdata in enumerate(validation_loader): + # NOTE: move to XLA device + vinputs, vlabels = pytree.tree_map_only( + torch.Tensor, + torch_xla2.tensor.move_to_device, + vdata) + voutputs = model(vinputs) # call model's forward + vloss = loss_fn(voutputs, vlabels) + running_vloss += vloss + + avg_vloss = running_vloss / (i + 1) + print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) + + # Log the running loss averaged per batch + # for both training and validation + writer.add_scalars('Training vs. Validation Loss', + { 'Training' : avg_loss, 'Validation' : avg_vloss }, + epoch_number + 1) + writer.flush() + + # Track best performance, and save the model's state + if avg_vloss < best_vloss: + best_vloss = avg_vloss + model_path = 'model_{}_{}'.format(timestamp, epoch_number) + torch.save(model.state_dict(), model_path) + + epoch_number += 1 \ No newline at end of file diff --git a/experimental/torch_xla2/examples/basic_training_jax.py b/experimental/torch_xla2/examples/basic_training_jax.py new file mode 100644 index 00000000000..3941fcdf8fe --- /dev/null +++ b/experimental/torch_xla2/examples/basic_training_jax.py @@ -0,0 +1,196 @@ +""" +This is the script from this tutorial: +https://pytorch.org/tutorials/beginner/introyt/trainingyt.html +""" + +import torch +from torch.utils import _pytree as pytree +import torchvision +import torchvision.transforms as transforms +import torch_xla2 +import torch_xla2.extra +import jax +import optax +import numpy as np + +# PyTorch TensorBoard support +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime + + +transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,))]) + +# Create datasets for training & validation, download if necessary +training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True) 
+validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True) + +# Create data loaders for our datasets; shuffle for training, not for validation +training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True) +validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False) + +# Class labels +classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', + 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot') + +# Report split sizes +print('Training set has {} instances'.format(len(training_set))) +print('Validation set has {} instances'.format(len(validation_set))) + +import matplotlib.pyplot as plt +import numpy as np + +# Helper function for inline image display +def matplotlib_imshow(img, one_channel=False): + if one_channel: + img = img.mean(dim=0) + img = img / 2 + 0.5 # unnormalize + npimg = img.numpy() + if one_channel: + plt.imshow(npimg, cmap="Greys") + else: + plt.imshow(np.transpose(npimg, (1, 2, 0))) + +dataiter = iter(training_loader) +images, labels = next(dataiter) + +# Create a grid from the images and show them +img_grid = torchvision.utils.make_grid(images) +matplotlib_imshow(img_grid, one_channel=True) +print(' '.join(classes[labels[j]] for j in range(4))) + + +import torch.nn as nn +import torch.nn.functional as F + +# PyTorch models inherit from torch.nn.Module +class GarmentClassifier(nn.Module): + def __init__(self): + super(GarmentClassifier, self).__init__() + self.fc1 = nn.Linear(28 * 28, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +model = GarmentClassifier() +loss_fn = torch.nn.CrossEntropyLoss() + +jax_weights, jax_func = torch_xla2.extract_jax(model) +jax_func = jax.jit(jax_func, inline=True) +jax_optimizer = optax.adam(0.01) +opt_state = jax_optimizer.init(jax_weights) + + 
+def jax_loss(weights, data, label): + pred = jax_func(weights, data) + loss = torch_xla2.extra.call_torch(loss_fn, pred, label) + return loss + +grad_fn = jax.jit(jax.value_and_grad(jax_loss)) + + +# NB: Loss functions expect data in batches, so we're creating batches of 4 +# Represents the model's confidence in each of the 10 classes for a given input +dummy_outputs = torch.rand(4, 10) +# Represents the correct class among the 10 being tested +dummy_labels = torch.tensor([1, 5, 3, 7]) + +print(dummy_outputs) +print(dummy_labels) + +loss = loss_fn(dummy_outputs, dummy_labels) +print('Total loss for this batch: {}'.format(loss.item())) + + +def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): + + running_loss = 0. + last_loss = 0. + + # Here, we use enumerate(training_loader) instead of + # iter(training_loader) so that we can track the batch + # index and do some intra-epoch reporting + for i, data in enumerate(training_loader): + # Every data instance is an input + label pair + # NEW: Move model to XLA device + data = pytree.tree_map_only(torch.Tensor, + torch_xla2.tensor.t2j, data) + inputs, labels = data + + val, grads = grad_fn(jax_weights, (inputs, ), labels) + updates, opt_state = jax_optimizer.update(grads, opt_state) + jax_weights = optax.apply_updates(jax_weights, updates) + + # Gather data and report + running_loss += val.item() + if i % 1000 == 999: + last_loss = running_loss / 1000 # loss per batch + print(' batch {} loss: {}'.format(i + 1, last_loss)) + tb_x = epoch_index * len(training_loader) + i + 1 + tb_writer.add_scalar('Loss/train', last_loss, tb_x) + running_loss = 0. + + return last_loss, opt_state + + + +# Initializing in a separate cell so we can easily add more epochs to the same run +timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') +writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) +epoch_number = 0 +EPOCHS = 2 +best_vloss = 1_000_000. 
+ +for epoch in range(EPOCHS): + print('EPOCH {}:'.format(epoch_number + 1)) + + # Make sure gradient tracking is on, and do a pass over the data + model.train(True) + + # NEW: Move model to XLA device + state_dict = model.state_dict() + state_dict = pytree.tree_map_only(torch.Tensor, + torch_xla2.tensor.move_to_device, state_dict) + model.load_state_dict(state_dict, strict=False, assign=True) + + avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer) + + running_vloss = 0.0 + # Set the model to evaluation mode, disabling dropout and using population + # statistics for batch normalization. + model.eval() + + # Disable gradient computation and reduce memory consumption. + with torch.no_grad(): + for i, vdata in enumerate(validation_loader): + + vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata) + voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward + vloss = torch_xla2.extra.call_torch(loss_fn, voutputs, vlabels) + running_vloss += vloss + + avg_vloss = running_vloss / (i + 1) + print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) + + # Log the running loss averaged per batch + # for both training and validation + writer.add_scalars('Training vs. 
Validation Loss', + { 'Training' : np.asarray(avg_loss), 'Validation' : np.asarray(avg_vloss) }, + epoch_number + 1) + writer.flush() + + # Track best performance, and save the model's state + if avg_vloss < best_vloss: + best_vloss = avg_vloss + model_path = 'model_{}_{}'.format(timestamp, epoch_number) + torch.save(model.state_dict(), model_path) + + epoch_number += 1 \ No newline at end of file diff --git a/experimental/torch_xla2/examples/eager_mode.py b/experimental/torch_xla2/examples/eager_mode.py new file mode 100644 index 00000000000..358ee6256c6 --- /dev/null +++ b/experimental/torch_xla2/examples/eager_mode.py @@ -0,0 +1,42 @@ + +from torch_xla2.tensor import move_to_device +import torch_xla2 +from torch import nn +from torch.nn import functional as F +import torch +from torch.utils import _pytree as pytree + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + +m = MyModel() + +# Execute this model using torch +inputs = (torch.randn(3, 3, 28, 28), ) + +inputs, state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, (inputs, m.state_dict())) +m.load_state_dict(state_dict, strict=False, assign=True) +print(m(*inputs)) +print('---=====') + +from torch_xla2.extra import jax_jit + +@jax_jit +def model_func(param, inputs): + return torch.func.functional_call(m, param, inputs) + +print(model_func(state_dict, inputs)) + + diff --git a/experimental/torch_xla2/examples/requirements.txt b/experimental/torch_xla2/examples/requirements.txt new file mode 100644 index 00000000000..69e01ff3dd0 --- /dev/null +++ b/experimental/torch_xla2/examples/requirements.txt @@ -0,0 +1,3 @@ +torchvision +matplotlib +optax \ No newline at end of file diff --git a/experimental/torch_xla2/test/llama/test_llama.py 
b/experimental/torch_xla2/test/llama/test_llama.py index 69ec3c33aef..dae7bf0cc5c 100644 --- a/experimental/torch_xla2/test/llama/test_llama.py +++ b/experimental/torch_xla2/test/llama/test_llama.py @@ -33,7 +33,6 @@ def test_can_run(self): # NOTE: this API does NOT use torch export weights, jax_func = torch_xla2.extract_jax(m) - print(jax_func(weights, sample_args)) def test_can_run_exportable(self): diff --git a/experimental/torch_xla2/torch_xla2/_ops.py b/experimental/torch_xla2/torch_xla2/_ops.py index adbcef5eaad..fe0f97a0f01 100644 --- a/experimental/torch_xla2/torch_xla2/_ops.py +++ b/experimental/torch_xla2/torch_xla2/_ops.py @@ -1684,3 +1684,62 @@ def _aten_scalar_tensor(s, dtype = tensor.t2j_dtype(dtype) return jnp.array(s, dtype=dtype) return jnp.array(s) + + +@op(torch.ops.aten.to.device) +def _aten_to_device(x,device, dtype): + return x + + +@op(torch.ops.aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices): + + """ + Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. + + Args: + grad_output: The gradient tensor from the preceding layer. + self: The input tensor on which the original max pooling was performed. + kernel_size: The size of the pooling window. + stride: The stride of the pooling window. + padding: The padding applied during max pooling. + dilation: The dilation factor for the pooling operation. + ceil_mode: Whether to use ceil or floor when calculating output shapes. + indices: The indices of the maximum values, as produced by max_pool2d_with_indices. + + Returns: + The calculated gradient with respect to the input (grad_input). 
+ """ + + kH, kW = kernel_size + dH, dW = stride + padH, padW = padding + dilH, dilW = dilation + + # Calculate output shape (may need adjustment based on ceil_mode) + out_shape = jnp.array(self.shape) + grad_input = jnp.zeros_like(self) + + # Iterate over the flattened input and output tensors + for i, idx in enumerate(indices.flatten()): + # Calculate input coordinates corresponding to the maximum value + out_y, out_x = i // grad_output.shape[3], i % grad_output.shape[3] + in_y = out_y * dH - padH + out_y * (dilH - 1) + in_x = out_x * dW - padW + out_x * (dilW - 1) + + # Scatter the gradient to the appropriate input locations (handling potential overlaps) + for y in range(in_y, in_y + kH): + for x in range(in_x, in_x + kW): + if 0 <= y < grad_input.shape[2] and 0 <= x < grad_input.shape[3]: + grad_input = grad_input.at[y, x].add(grad_output.flatten()[i]) + + return grad_input + + +@op(torch.ops.aten._local_scalar_dense) +def _aten_local_scalar_dense(x): + return x.item() + +@op(torch.ops.aten.tensor_split.sections) +def _aten_tensor_split(ary, indices_or_sections, axis=0): + return jnp.array_split(ary, indices_or_sections, axis) \ No newline at end of file diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 86a9f9c9c9f..98953a8b04c 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -199,7 +199,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if isinstance(func, torch._ops.OpOverloadPacket): return func(*args, **kwargs) - print(func.name()) if func.name() == 'aten::copy_': x, y = args x._elem = y._elem