diff --git a/experimental/torch_xla2/examples/mnist_tpu.ipynb b/experimental/torch_xla2/examples/mnist_tpu.ipynb new file mode 100644 index 00000000000..ff41f276fce --- /dev/null +++ b/experimental/torch_xla2/examples/mnist_tpu.ipynb @@ -0,0 +1,667 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tHNudaKYx4Ci", + "outputId": "d72d15a1-483f-4820-de3c-0ef8905cb1ed" + }, + "outputs": [], + "source": [ + "# Uncomment and run these if you haven't already installed `torch_xla2`\n", + "#!pip uninstall -y tensorflow\n", + "#!pip install tpu-info 'torch_xla2[tpu] @ git+https://github.com/pytorch/xla.git#subdirectory=experimental/torch_xla2' -f https://storage.googleapis.com/libtpu-releases/index.html\n", + "#!pip install torchvision" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distributed training with `torch_xla2`\n", + "\n", + "This notebook demonstrates how to perform distributed training using `torch_xla2`, which allows you to run PyTorch models with JAX." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset and model setup\n", + "\n", + "Below, we download and preprocess the MNIST dataset and instantiate a simple neural network to use as an example. The details aren't important here. You can follow the same steps below for any PyTorch model and dataset.\n", + "\n", + "A couple of important notes about this section:\n", + "\n", + "- When we're loading data, each batch will be split across all local devices.\n", + "- `model` remains on the CPU device. We'll move it to the TPU in the next step." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "dbNWnxtizF-Z" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "\n", + "train_dataset = torchvision.datasets.MNIST(\n", + "    root='./data',\n", + "    train=True,\n", + "    download=True,\n", + "    transform=transforms.Compose(\n", + "        [transforms.ToTensor(),\n", + "         transforms.Normalize((0.1307,), (0.3081,))]))\n", + "test_dataset = torchvision.datasets.MNIST(\n", + "    root='./data',\n", + "    train=False,\n", + "    download=True,\n", + "    transform=transforms.Compose(\n", + "        [transforms.ToTensor(),\n", + "         transforms.Normalize((0.1307,), (0.3081,))]))\n", + "\n", + "train_loader = torch.utils.data.DataLoader(\n", + "    train_dataset,\n", + "    batch_size=128,\n", + "    drop_last=True,\n", + "    shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(\n", + "    test_dataset,\n", + "    batch_size=128,\n", + "    drop_last=True,\n", + "    shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "_p2gxDdv6RYo" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "model = nn.Sequential(\n", + "    nn.Flatten(),\n", + "    nn.Linear(784, 512),\n", + "    nn.ReLU(),\n", + "    nn.Linear(512, 512),\n", + "    nn.ReLU(),\n", + "    nn.Linear(512, 512),\n", + "    nn.ReLU(),\n", + "    nn.Linear(512, 10)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Replicating the model across devices\n", + "\n", + "Most TPU configurations include multiple TPU chips per host. For example, a v4-8 TPU has 4 chips total. We can use `tpu-info` to see how many devices are available on this host.\n",
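+ "\n", + "JAX can also enumerate these devices directly from Python. As a quick, optional check (nothing later depends on it), each entry should line up with one chip in the `tpu-info` table below:\n", + "\n", + "```python\n", + "import jax\n", + "\n", + "# Prints one TpuDevice entry per chip on this host.\n", + "for device in jax.devices():\n", + "    print(device)\n", + "```"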
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[3mTPU Chips \u001b[0m\n", + "┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mDevice     \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mType       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mCores\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPID \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━┩\n", + "│ /dev/accel0 │ TPU v4 chip │ 1     │ None │\n", + "│ /dev/accel1 │ TPU v4 chip │ 1     │ None │\n", + "│ /dev/accel2 │ TPU v4 chip │ 1     │ None │\n", + "│ /dev/accel3 │ TPU v4 chip │ 1     │ None │\n", + "└─────────────┴─────────────┴───────┴──────┘\n", + "Libtpu metrics unavailable. Did you start a workload with `TPU_RUNTIME_METRICS_PORTS=8431,8432,8433,8434`?\n" + ] + } + ], + "source": [ + "!tpu-info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`torch_xla2` uses JAX as a backend, so we can use JAX to double-check the device count. Don't worry -- we won't have to use JAX directly to run the model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "M1wGEXY4yRvG", + "outputId": "4bea9105-062d-45d6-bd37-d47e9d06cad6" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "\n", + "# The TPU device count will vary depending on your environment.\n", + "jax.device_count()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The device count above should match the output of `tpu-info` (4 devices in the case of a v4-8).\n", + "\n", + "In this example, we'll use `torch_xla2`'s custom `DistributedDataParallel` implementation to replicate the model parameters across all available TPU devices and split the input data across them." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "Y9uhN5Om0f25" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/wcromar/tx2/.venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:270: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.\n", + "  warnings.warn(\n" + ] + } + ], + "source": [ + "import torch_xla2\n", + "\n", + "ddp_model = torch_xla2.distributed.DistributedDataParallel(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can dig into the underlying JAX array to see that there's an identical copy of the parameter tensor on each TPU device:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "example_param = next(ddp_model.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Shard(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=0, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", + " 0.00225713]\n", + " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", + " 0.01084254]\n", + " [-0.01985117 -0.01139126 -0.00223861 ... 
-0.02136385 0.0339912\n", + " -0.02596978]\n", + " ...\n", + " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", + " -0.02620777]\n", + " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", + " -0.02689848]\n", + " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", + " 0.00852821]]),\n", + " Shard(device=TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=1, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", + " 0.00225713]\n", + " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", + " 0.01084254]\n", + " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", + " -0.02596978]\n", + " ...\n", + " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", + " -0.02620777]\n", + " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", + " -0.02689848]\n", + " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", + " 0.00852821]]),\n", + " Shard(device=TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=2, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", + " 0.00225713]\n", + " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", + " 0.01084254]\n", + " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", + " -0.02596978]\n", + " ...\n", + " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", + " -0.02620777]\n", + " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", + " -0.02689848]\n", + " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", + " 0.00852821]]),\n", + " Shard(device=TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=3, data=[[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", + " 0.00225713]\n", + " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", + " 0.01084254]\n", + " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", + " -0.02596978]\n", + " ...\n", + " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", + " -0.02620777]\n", + " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", + " -0.02689848]\n", + " [ 0.01867637 0.0117135 0.02216029 ... 0.00011777 0.02212651\n", + " 0.00852821]])]\n" + ] + } + ], + "source": [ + "import pprint\n", + "pprint.pprint(example_param._elem.addressable_shards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The replicated tensor still behaves as a plain PyTorch tensor, however:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "XLATensor2( [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n", + " 0.00225713]\n", + " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n", + " 0.01084254]\n", + " [-0.01985117 -0.01139126 -0.00223861 ... -0.02136385 0.0339912\n", + " -0.02596978]\n", + " ...\n", + " [ 0.0168394 0.0063334 -0.02949585 ... -0.0254653 0.03273752\n", + " -0.02620777]\n", + " [-0.00896274 -0.03342744 -0.0269749 ... 0.01811987 0.03423703\n", + " -0.02689848]\n", + " [ 0.01867637 0.0117135 0.02216029 ... 
0.00011777 0.02212651\n", + " 0.00852821]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example_param" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sharding inputs\n", + "\n", + "Unlike the model parameters, the input data should not be replicated: we want to send a different shard of each batch to each device. We'll take one batch of images as an example:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([128, 1, 28, 28])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example_images, _ = next(iter(train_loader))\n", + "example_images.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sharding the input batch across devices does not change the overall size of the tensor:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 1, 28, 28)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sharded_example_images = ddp_model.shard_input(example_images)\n", + "sharded_example_images.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we dig into the underlying JAX array, we can see that the input has been split across the batch dimension (into quarters of 32 examples each in this case):" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(32, 1, 28, 28), (32, 1, 28, 28), (32, 1, 28, 28), (32, 1, 28, 28)]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[s.data.shape for s in sharded_example_images._elem.addressable_shards]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Putting it all together\n", + "\n", + "`torch_xla2` allows us to seamlessly shard and replicate tensors across devices, while still maintaining a single logical view of each tensor through PyTorch. With some minor changes, we can adapt the conventional PyTorch training loop to use the TPU.\n", + "\n", + "Note that we do not have to spawn any child processes. Although each parameter and input is represented by one tensor, that tensor is already distributed across multiple devices." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The loss function and optimizer stay the same:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "E5QjcpuY1hx5" + }, + "outputs": [], + "source": [ + "loss_fn = torch.nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "JAX code performs significantly better when compiled, normally through `jax.jit`. `torch_xla2`'s DDP implementation contains a utility `jit_step` that can be used to compile a training step. Note that for this to work, the training step must be separated out into a function. Otherwise, the contents of the training step are the same as they would be for eager CPU or GPU execution.\n",
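+ "\n", + "For reference, here is roughly what `jax.jit` does on its own. This standalone sketch is independent of the notebook's model; `jit_step` applies the same idea to the training step, while also handling the conversion between torch tensors and JAX arrays behind the scenes:\n", + "\n", + "```python\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "@jax.jit  # traced once per input shape, then reused as a compiled XLA program\n", + "def scaled_sum(x):\n", + "    return jnp.sum(x * 2.0)\n", + "\n", + "scaled_sum(jnp.ones((4, 4)))  # the first call compiles; later calls run the cached program\n", + "```"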
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "AojhVVzx0ZEG" + }, + "outputs": [], + "source": [ + "@ddp_model.jit_step\n", + "def train_step(sharded_inputs, sharded_labels):\n", + "    optimizer.zero_grad()\n", + "    outputs = ddp_model(sharded_inputs)\n", + "    loss = loss_fn(outputs, sharded_labels)\n", + "    loss.backward()\n", + "    optimizer.step()\n", + "\n", + "    return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's run training for several epochs and check the validation results:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "QhO7V7JR2l8A" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " batch 0 loss: 2.3075523376464844\n", + " batch 100 loss: 2.3029651641845703\n", + " batch 200 loss: 2.2921366691589355\n", + " batch 300 loss: 2.2877070903778076\n", + " batch 400 loss: 2.274242401123047\n", + "Epoch 1\n", + " batch 0 loss: 2.2708349227905273\n", + " batch 100 loss: 2.269294261932373\n", + " batch 200 loss: 2.2480335235595703\n", + " batch 300 loss: 2.243983268737793\n", + " batch 400 loss: 2.2470455169677734\n", + "Epoch 2\n", + " batch 0 loss: 2.234013557434082\n", + " batch 100 loss: 2.2184624671936035\n", + " batch 200 loss: 2.2029666900634766\n", + " batch 300 loss: 2.198725461959839\n", + " batch 400 loss: 2.1829864978790283\n", + "Epoch 3\n", + " batch 0 loss: 2.1811957359313965\n", + " batch 100 loss: 2.1297898292541504\n", + " batch 200 loss: 2.1378531455993652\n", + " batch 300 loss: 2.0720174312591553\n", + " batch 400 loss: 2.0413732528686523\n", + "Epoch 4\n", + " batch 0 loss: 2.046309471130371\n", + " batch 100 loss: 1.9817270040512085\n", + " batch 200 loss: 1.9381718635559082\n", + " batch 300 loss: 1.847656011581421\n", + " batch 400 loss: 1.808678388595581\n", + "Epoch 5\n", + " batch 0 loss: 1.7617125511169434\n", + " batch 100 loss: 1.768508791923523\n", + " batch 200 loss: 1.6427236795425415\n", + " batch 300 loss: 1.6908036470413208\n", + " batch 400 loss: 1.538255214691162\n", + "Epoch 6\n", + " batch 0 loss: 1.4774806499481201\n", + " batch 100 loss: 1.4533928632736206\n", + " batch 200 loss: 1.2804057598114014\n", + " batch 300 loss: 1.2498115301132202\n", + " batch 400 loss: 1.116618275642395\n", + "Epoch 7\n", + " batch 0 loss: 1.1049035787582397\n", + " batch 100 loss: 1.0565766096115112\n", + " batch 200 loss: 1.0216108560562134\n", + " batch 300 loss: 0.9548335671424866\n", + " batch 400 loss: 0.8766275644302368\n", + "Epoch 8\n", + " batch 0 loss: 0.7384852766990662\n", + " batch 100 loss: 0.8499367237091064\n", + " batch 200 loss: 0.8409233689308167\n", + " batch 300 loss: 0.7746399641036987\n", + " batch 400 loss: 0.8063997030258179\n", + "Epoch 9\n", + " batch 0 loss: 0.7310354709625244\n", + " batch 100 loss: 0.825514018535614\n", + " batch 200 loss: 0.6718677878379822\n", + " batch 300 loss: 0.7210809588432312\n", + " batch 400 loss: 0.7002769708633423\n" + ] + } + ], + "source": [ + "for epoch in range(10):\n", + "    print('Epoch', epoch)\n", + "    for i, data in enumerate(train_loader):\n", + "        inputs, labels = data\n", + "        # Distribute the batch across all TPU cores\n", + "        sharded_inputs, sharded_labels = ddp_model.shard_input(inputs), ddp_model.shard_input(labels)\n", + "        loss = train_step(sharded_inputs, sharded_labels)\n", + "\n", + "        if i % 100 == 0:\n", + "            print(' batch {} loss: {}'.format(i, loss.item()))" + ] + },
+ { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss 0.6315549612045288\n" + ] + } + ], + "source": [ + "@ddp_model.jit_step\n", + "def eval_step(sharded_vinputs, sharded_vlabels):\n", + "    voutputs = ddp_model(sharded_vinputs)\n", + "    vloss = loss_fn(voutputs, sharded_vlabels)\n", + "    return vloss\n", + "\n", + "ddp_model.eval()\n", + "running_vloss = 0.\n", + "\n", + "# Disable gradient computation to reduce memory consumption.\n", + "with torch.no_grad():\n", + "    for i, vdata in enumerate(test_loader):\n", + "        vinputs, vlabels = vdata\n", + "        sharded_vinputs, sharded_vlabels = ddp_model.shard_input(vinputs), ddp_model.shard_input(vlabels)\n", + "        vloss = eval_step(sharded_vinputs, sharded_vlabels)\n", + "        running_vloss += vloss\n", + "\n", + "avg_vloss = running_vloss / (i + 1)\n", + "print('Validation loss', avg_vloss.item())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "With some minor changes to your training loop, `torch_xla2` allows you to distribute a model across multiple devices and run a compiled version of it with JAX. All of the data you interact with directly is still a `torch` tensor, and JAX handles all of the distributed details in the background.\n", + "\n", + "`torch_xla2` (and distributed training support in particular) is still under heavy development. To learn more about the project and its current status, see https://github.com/pytorch/xla/tree/master/experimental/torch_xla2" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V28", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/experimental/torch_xla2/torch_xla2/distributed.py b/experimental/torch_xla2/torch_xla2/distributed.py index 4ad6dba58f4..dacc4d69765 100644 --- a/experimental/torch_xla2/torch_xla2/distributed.py +++ b/experimental/torch_xla2/torch_xla2/distributed.py @@ -208,7 +208,7 @@ def shard_input(self, inp): global_batch_shape = (global_batch_size,) + inp.shape[1:] sharding = NamedSharding(self._mesh, P("batch")) - return jax.make_array_from_single_device_arrays( + return self._env.j2t_iso(jax.make_array_from_single_device_arrays( global_batch_shape, NamedSharding(self._mesh, P("batch")), arrays=[ @@ -217,7 +217,7 @@ def shard_input(self, inp): per_replica_batches, sharding.addressable_devices ) ], - ) + )) def replicate_input(self, inp): return self._env.j2t_iso(