diff --git a/README.md b/README.md index 9a1a340d9c1..43a78266685 100644 --- a/README.md +++ b/README.md @@ -96,8 +96,7 @@ If you're using `DistributedDataParallel`, make the following changes: + dist.init_process_group("xla", init_method='xla://') + + model.to(xm.xla_device()) -+ # `gradient_as_bucket_view=True` required for XLA -+ ddp_model = DDP(model, gradient_as_bucket_view=True) ++ ddp_model = DDP(model) - model = model.to(rank) - ddp_model = DDP(model, device_ids=[rank]) diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index 5a517565847..b96bb346643 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -1 +1,522 @@ -{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Distributed PyTorch/XLA Basics\n\nBeginning with PyTorch/XLA 2.0, Kaggle supports the new PJRT preview runtime on TPU VMs! For more information about PJRT, see the [PyTorch/XLA GitHub repository](https://github.com/pytorch/xla/blob/master/docs/pjrt.md).\n\nPyTorch/XLA is a package that lets PyTorch run on TPU devices. Kaggle provides a free v3-8 TPU VM. v3-8 TPUs have 8 logical devices: 4 TPU chips, each having 2 cores. This notebook shows how to run simple distributed operations on a TPU using the PJRT runtime. For more information about the Cloud TPU architecture, [see the official documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).\n\nAt the time of writing Kaggle Notebooks on TPU VM are preinstalled with Python 3.10 and PT/XLA 2.1. See below for the exact versions.","metadata":{}},{"cell_type":"code","source":"!python --version","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:01.251567Z","iopub.execute_input":"2024-01-10T19:30:01.251900Z","iopub.status.idle":"2024-01-10T19:30:01.378200Z","shell.execute_reply.started":"2024-01-10T19:30:01.251872Z","shell.execute_reply":"2024-01-10T19:30:01.377121Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stdout","text":"Python 3.10.13\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\ntorch.__version__","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:01.380122Z","iopub.execute_input":"2024-01-10T19:30:01.380390Z","iopub.status.idle":"2024-01-10T19:30:22.624453Z","shell.execute_reply.started":"2024-01-10T19:30:01.380364Z","shell.execute_reply":"2024-01-10T19:30:22.623753Z"},"trusted":true},"execution_count":2,"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"'2.1.0+cu121'"},"metadata":{}}]},{"cell_type":"code","source":"import torch_xla\ntorch_xla.__version__","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:22.625450Z","iopub.execute_input":"2024-01-10T19:30:22.625785Z","iopub.status.idle":"2024-01-10T19:30:28.439813Z","shell.execute_reply.started":"2024-01-10T19:30:22.625759Z","shell.execute_reply":"2024-01-10T19:30:28.439042Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n","output_type":"stream"},{"execution_count":3,"output_type":"execute_result","data":{"text/plain":"'2.1.0+libtpu'"},"metadata":{}}]},{"cell_type":"markdown","source":"Unlike JAX or TensorFlow, the convention in PyTorch is to start a separate child process per device to minimize the impact of Python's [Global Interpreter Lock](https://en.wikipedia.org/wiki/Global_interpreter_lock). In eager PyTorch, this means spawning one child process per GPU. For more information, see [PyTorch's distributed training documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#comparison-between-dataparallel-and-distributeddataparallel).\n\nDue to architectural constraints, it is not possible for more than one process to access a TPU chip simultaneously. Because TPU v3 has two TensorCores cores per TPU chip, that means that each process must drive at least two TPU cores. By default, PyTorch/XLA will spawn 4 processes in total (one per chip), each having two threads (one per TensorCore). This is handled transparently by `xmp.spawn`, which mirrors `mp.spawn`. However, it is important to keep in mind that _all distributed workloads on a TPU v2 or v3 are multithreaded_. The function you pass to `spawn` should be thread-safe.\n\nTPU v4 has a different architecture, where each TPU chip is represented to PyTorch as a single device, so we spawn one process per device as expected.\n\nSee the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) for an in-depth look at TPU architecture.","metadata":{}},{"cell_type":"code","source":"!printenv","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Temporary hack: remove some TPU environment variables to support multiprocessing\n# These will be set later by xmp.spawn.\n\nimport os\nos.environ.pop('TPU_PROCESS_ADDRESSES')\nos.environ.pop('CLOUD_TPU_TASK_ID')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.574526Z","iopub.execute_input":"2024-01-10T19:30:28.574793Z","iopub.status.idle":"2024-01-10T19:30:28.580178Z","shell.execute_reply.started":"2024-01-10T19:30:28.574764Z","shell.execute_reply":"2024-01-10T19:30:28.579554Z"},"trusted":true},"execution_count":5,"outputs":[{"execution_count":5,"output_type":"execute_result","data":{"text/plain":"'0'"},"metadata":{}}]},{"cell_type":"code","source":"import torch_xla.core.xla_model as xm\nimport torch_xla.distributed.xla_multiprocessing as xmp","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.581117Z","iopub.execute_input":"2024-01-10T19:30:28.581455Z","iopub.status.idle":"2024-01-10T19:30:28.606274Z","shell.execute_reply.started":"2024-01-10T19:30:28.581429Z","shell.execute_reply":"2024-01-10T19:30:28.605671Z"},"trusted":true},"execution_count":6,"outputs":[]},{"cell_type":"markdown","source":"To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. 
Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`.","metadata":{}},{"cell_type":"code","source":"import multiprocessing as mp\nlock = mp.Manager().Lock()\n\ndef print_device(i, lock):\n device = xm.xla_device()\n with lock:\n print('process', i, device)","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.607138Z","iopub.execute_input":"2024-01-10T19:30:28.607393Z","iopub.status.idle":"2024-01-10T19:30:28.664032Z","shell.execute_reply.started":"2024-01-10T19:30:28.607368Z","shell.execute_reply":"2024-01-10T19:30:28.662583Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"markdown","source":"To run a function on each TPU device, pass it to `xmp.spawn`. We'll use an `mp.Lock` to prevent `print` statements from overlapping between processes. This make the output clearer, but it is optional.\n\nNote: in interactive notebooks, you must use `start_method='fork'`.","metadata":{}},{"cell_type":"code","source":"xmp.spawn(print_device, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.666339Z","iopub.execute_input":"2024-01-10T19:30:28.666657Z","iopub.status.idle":"2024-01-10T19:30:33.218095Z","shell.execute_reply.started":"2024-01-10T19:30:28.666621Z","shell.execute_reply":"2024-01-10T19:30:33.216950Z"},"trusted":true},"execution_count":8,"outputs":[{"name":"stdout","text":"process 4 xla:0\nprocess 0 xla:0\nprocess 1 xla:1\nprocess 6 xla:0\nprocess 7 xla:1\nprocess 5 xla:1\nprocess 2 xla:0\nprocess 3 xla:1\n","output_type":"stream"}]},{"cell_type":"markdown","source":"Note: ignore the errors from `oauth2_credentials.cc`. These will be fixed in a future release.","metadata":{}},{"cell_type":"markdown","source":"To run `torch` operations on a TPU, pass the corresponding XLA device in as the `device` parameter. When you pass in an XLA device, the operation is added to a graph, which is executed lazily as needed. 
To force all devices to evaluate the current graph, call `xm.mark_step()`.","metadata":{}},{"cell_type":"code","source":"def add_ones(i, lock):\n x = torch.ones((3, 3), device=xm.xla_device())\n y = x + x\n \n # Run graph to compute `y` before printing\n xm.mark_step()\n \n with lock:\n print(i, y)\n\nxmp.spawn(add_ones, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:33.219569Z","iopub.execute_input":"2024-01-10T19:30:33.219878Z","iopub.status.idle":"2024-01-10T19:30:35.653084Z","shell.execute_reply.started":"2024-01-10T19:30:33.219847Z","shell.execute_reply":"2024-01-10T19:30:35.651887Z"},"trusted":true},"execution_count":9,"outputs":[{"name":"stdout","text":"6 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n7 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n0 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n2 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n3 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n1 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n4 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n5 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n","output_type":"stream"}]},{"cell_type":"markdown","source":"To communicate tensors between TPU devices, use the collective communication operations in `xla_model`, such as `all_gather`.\n\n","metadata":{}},{"cell_type":"code","source":"def gather_ids(i, lock):\n # Create a tensor on each device with the device ID\n t = torch.tensor([i], device=xm.xla_device())\n with lock:\n print(i, t)\n \n # Collect and concatenate the IDs\n ts = xm.all_gather(t)\n xm.mark_step()\n with lock:\n print(i, ts)\n\nxmp.spawn(gather_ids, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:35.656377Z","iopub.execute_input":"2024-01-10T19:30:35.656796Z","iopub.status.idle":"2024-01-10T19:30:38.314318Z","shell.execute_reply.started":"2024-01-10T19:30:35.656763Z","shell.execute_reply":"2024-01-10T19:30:38.313118Z"},"trusted":true},"execution_count":10,"outputs":[{"name":"stdout","text":"7 tensor([7], device='xla:1')\n6 tensor([6], device='xla:0')\n4 tensor([4], device='xla:0')\n5 tensor([5], device='xla:1')\n3 tensor([3], device='xla:1')\n2 tensor([2], device='xla:0')\n1 tensor([1], device='xla:1')\n0 tensor([0], device='xla:0')\n7 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n6 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n5 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n4 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n3 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n2 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n0 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n1 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n","output_type":"stream"}]},{"cell_type":"markdown","source":"PyTorch/XLA 2.0 also ships with experimental support for the `torch.distributed` using the `pjrt://` `init_method`, including `DistributedDataParallel`.\n\nBecause replicas are run multithreaded, the distributed function must be thread-safe. However, the global RNG that `torch` uses for module initialization will give inconsistent results between replicas on TPU v3, since there will be multiple threads concurrently using it. To ensure consistent parameters, we recommend broadcasting model parameters from replica 0 to the other replicas using `pjrt.broadcast_master_param`. 
In practice, you may also load each replica's parameters from a common checkpoint.","metadata":{}},{"cell_type":"code","source":"import torch.distributed as dist\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nimport torch.optim as optim\n\nimport torch_xla.distributed.xla_backend # Registers `xla://` init_method\nimport torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n\ndef toy_model(index, lock):\n device = xm.xla_device()\n dist.init_process_group('xla', init_method='xla://')\n\n # Initialize a basic toy model\n torch.manual_seed(42)\n model = nn.Linear(128, 10).to(device)\n\n # Optional for TPU v4 and GPU\n xm.broadcast_master_param(model)\n\n # `gradient_as_bucket_view=True` required for XLA\n model = DDP(model, gradient_as_bucket_view=True)\n\n loss_fn = nn.MSELoss()\n optimizer = optim.SGD(model.parameters(), lr=.001)\n\n for i in range(10):\n # Generate random inputs and outputs on the XLA device\n data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)\n\n optimizer.zero_grad()\n output = model(data)\n loss = loss_fn(output, target)\n loss.backward()\n\n optimizer.step()\n \n # Run the pending graph\n xm.mark_step()\n\n with lock:\n # Print mean parameters so we can confirm they're the same across replicas\n print(index, [p.mean() for p in model.parameters()])\n\nxmp.spawn(toy_model, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:38.315653Z","iopub.execute_input":"2024-01-10T19:30:38.315927Z","iopub.status.idle":"2024-01-10T19:30:43.491104Z","shell.execute_reply.started":"2024-01-10T19:30:38.315899Z","shell.execute_reply":"2024-01-10T19:30:43.490078Z"},"trusted":true},"execution_count":11,"outputs":[{"name":"stderr","text":"WARNING:root:Patching torch.distributed state to support multithreading.\nWARNING:root:torch.distributed support on TPU v2 and v3 is experimental and does not support torchrun.\n[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. 
(function operator())\n","output_type":"stream"},{"name":"stdout","text":"1 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n4 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n6 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n7 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n0 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n5 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n3 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n2 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n","output_type":"stream"}]},{"cell_type":"markdown","source":"For a more in-depth look at PyTorch/XLA, see our [API guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md).","metadata":{}}]} \ No newline at end of file +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distributed PyTorch/XLA Basics\n", + "\n", + "Beginning with PyTorch/XLA 2.0, Kaggle supports the new PJRT preview runtime on TPU VMs! For more information about PJRT, see the [PyTorch/XLA GitHub repository](https://github.com/pytorch/xla/blob/master/docs/pjrt.md).\n", + "\n", + "PyTorch/XLA is a package that lets PyTorch run on TPU devices. Kaggle provides a free v3-8 TPU VM. v3-8 TPUs have 8 logical devices: 4 TPU chips, each having 2 cores. This notebook shows how to run simple distributed operations on a TPU using the PJRT runtime. For more information about the Cloud TPU architecture, [see the official documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).\n", + "\n", + "At the time of writing Kaggle Notebooks on TPU VM are preinstalled with Python 3.10 and PT/XLA 2.1. See below for the exact versions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:01.251900Z", + "iopub.status.busy": "2024-01-10T19:30:01.251567Z", + "iopub.status.idle": "2024-01-10T19:30:01.378200Z", + "shell.execute_reply": "2024-01-10T19:30:01.377121Z", + "shell.execute_reply.started": "2024-01-10T19:30:01.251872Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python 3.10.13\n" + ] + } + ], + "source": [ + "!python --version" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:01.380390Z", + "iopub.status.busy": "2024-01-10T19:30:01.380122Z", + "iopub.status.idle": "2024-01-10T19:30:22.624453Z", + "shell.execute_reply": "2024-01-10T19:30:22.623753Z", + "shell.execute_reply.started": "2024-01-10T19:30:01.380364Z" + }, + "trusted": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.1.0+cu121'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "torch.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:22.625785Z", + "iopub.status.busy": "2024-01-10T19:30:22.625450Z", + "iopub.status.idle": "2024-01-10T19:30:28.439813Z", + "shell.execute_reply": "2024-01-10T19:30:28.439042Z", + "shell.execute_reply.started": "2024-01-10T19:30:22.625759Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "data": { + "text/plain": [ + "'2.1.0+libtpu'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch_xla\n", + "torch_xla.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Unlike JAX or TensorFlow, the convention in PyTorch is to start a separate child process per device to minimize the impact of Python's [Global Interpreter Lock](https://en.wikipedia.org/wiki/Global_interpreter_lock). In eager PyTorch, this means spawning one child process per GPU. For more information, see [PyTorch's distributed training documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#comparison-between-dataparallel-and-distributeddataparallel).\n", + "\n", + "Due to architectural constraints, it is not possible for more than one process to access a TPU chip simultaneously. Because TPU v3 has two TensorCores per TPU chip, each process must drive at least two TPU cores. By default, PyTorch/XLA will spawn 4 processes in total (one per chip), each having two threads (one per TensorCore). This is handled transparently by `xmp.spawn`, which mirrors `mp.spawn`. However, it is important to keep in mind that _all distributed workloads on a TPU v2 or v3 are multithreaded_. 
The function you pass to `spawn` should be thread-safe.\n", + "\n", + "TPU v4 has a different architecture, where each TPU chip is represented to PyTorch as a single device, so we spawn one process per device as expected.\n", + "\n", + "See the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) for an in-depth look at TPU architecture." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "!printenv" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:28.574793Z", + "iopub.status.busy": "2024-01-10T19:30:28.574526Z", + "iopub.status.idle": "2024-01-10T19:30:28.580178Z", + "shell.execute_reply": "2024-01-10T19:30:28.579554Z", + "shell.execute_reply.started": "2024-01-10T19:30:28.574764Z" + }, + "trusted": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'0'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Temporary hack: remove some TPU environment variables to support multiprocessing\n", + "# These will be set later by xmp.spawn.\n", + "\n", + "import os\n", + "os.environ.pop('TPU_PROCESS_ADDRESSES')\n", + "os.environ.pop('CLOUD_TPU_TASK_ID')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:28.581455Z", + "iopub.status.busy": "2024-01-10T19:30:28.581117Z", + "iopub.status.idle": "2024-01-10T19:30:28.606274Z", + "shell.execute_reply": "2024-01-10T19:30:28.605671Z", + "shell.execute_reply.started": "2024-01-10T19:30:28.581429Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "import torch_xla.core.xla_model as xm\n", + "import torch_xla.distributed.xla_multiprocessing as xmp" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:28.607393Z", + "iopub.status.busy": "2024-01-10T19:30:28.607138Z", + "iopub.status.idle": "2024-01-10T19:30:28.664032Z", + "shell.execute_reply": "2024-01-10T19:30:28.662583Z", + "shell.execute_reply.started": "2024-01-10T19:30:28.607368Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "import multiprocessing as mp\n", + "lock = mp.Manager().Lock()\n", + "\n", + "def print_device(i, lock):\n", + " device = xm.xla_device()\n", + " with lock:\n", + " print('process', i, device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To run a function on each TPU device, pass it to `xmp.spawn`. We'll use an `mp.Lock` to prevent `print` statements from overlapping between processes. This makes the output clearer, but it is optional.\n", + "\n", + "Note: in interactive notebooks, you must use `start_method='fork'`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:28.666657Z", + "iopub.status.busy": "2024-01-10T19:30:28.666339Z", + "iopub.status.idle": "2024-01-10T19:30:33.218095Z", + "shell.execute_reply": "2024-01-10T19:30:33.216950Z", + "shell.execute_reply.started": "2024-01-10T19:30:28.666621Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "process 4 xla:0\n", + "process 0 xla:0\n", + "process 1 xla:1\n", + "process 6 xla:0\n", + "process 7 xla:1\n", + "process 5 xla:1\n", + "process 2 xla:0\n", + "process 3 xla:1\n" + ] + } + ], + "source": [ + "xmp.spawn(print_device, args=(lock,), start_method='fork')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: ignore the errors from `oauth2_credentials.cc`. These will be fixed in a future release." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To run `torch` operations on a TPU, pass the corresponding XLA device in as the `device` parameter. When you pass in an XLA device, the operation is added to a graph, which is executed lazily as needed. To force all devices to evaluate the current graph, call `xm.mark_step()`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:33.219878Z", + "iopub.status.busy": "2024-01-10T19:30:33.219569Z", + "iopub.status.idle": "2024-01-10T19:30:35.653084Z", + "shell.execute_reply": "2024-01-10T19:30:35.651887Z", + "shell.execute_reply.started": "2024-01-10T19:30:33.219847Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:0')\n", + "7 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:1')\n", + "0 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:0')\n", + "2 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:0')\n", + "3 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:1')\n", + "1 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:1')\n", + "4 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:0')\n", + "5 tensor([[2., 2., 2.],\n", + " [2., 2., 2.],\n", + " [2., 2., 2.]], device='xla:1')\n" + ] + } + ], + "source": [ + "def add_ones(i, lock):\n", + " x = torch.ones((3, 3), device=xm.xla_device())\n", + " y = x + x\n", + " \n", + " # Run graph to compute `y` before printing\n", + " xm.mark_step()\n", + " \n", + " with lock:\n", + " print(i, y)\n", + "\n", + "xmp.spawn(add_ones, args=(lock,), start_method='fork')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To communicate tensors between TPU devices, use the collective communication operations in `xla_model`, such as `all_gather`.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:35.656796Z", + "iopub.status.busy": "2024-01-10T19:30:35.656377Z", + "iopub.status.idle": "2024-01-10T19:30:38.314318Z", + "shell.execute_reply": "2024-01-10T19:30:38.313118Z", + "shell.execute_reply.started": "2024-01-10T19:30:35.656763Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7 tensor([7], 
device='xla:1')\n", + "6 tensor([6], device='xla:0')\n", + "4 tensor([4], device='xla:0')\n", + "5 tensor([5], device='xla:1')\n", + "3 tensor([3], device='xla:1')\n", + "2 tensor([2], device='xla:0')\n", + "1 tensor([1], device='xla:1')\n", + "0 tensor([0], device='xla:0')\n", + "7 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n", + "6 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n", + "5 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n", + "4 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n", + "3 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n", + "2 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n", + "0 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n", + "1 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n" + ] + } + ], + "source": [ + "def gather_ids(i, lock):\n", + " # Create a tensor on each device with the device ID\n", + " t = torch.tensor([i], device=xm.xla_device())\n", + " with lock:\n", + " print(i, t)\n", + " \n", + " # Collect and concatenate the IDs\n", + " ts = xm.all_gather(t)\n", + " xm.mark_step()\n", + " with lock:\n", + " print(i, ts)\n", + "\n", + "xmp.spawn(gather_ids, args=(lock,), start_method='fork')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "PyTorch/XLA also ships with experimental support for `torch.distributed` using the `xla://` `init_method`, including `DistributedDataParallel`.\n", + "\n", + "Because replicas are run multithreaded, the distributed function must be thread-safe. However, the global RNG that `torch` uses for module initialization will give inconsistent results between replicas on TPU v3, since there will be multiple threads concurrently using it. To ensure consistent parameters, we recommend broadcasting model parameters from replica 0 to the other replicas using `xm.broadcast_master_param`. In practice, you may also load each replica's parameters from a common checkpoint." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2024-01-10T19:30:38.315927Z", + "iopub.status.busy": "2024-01-10T19:30:38.315653Z", + "iopub.status.idle": "2024-01-10T19:30:43.491104Z", + "shell.execute_reply": "2024-01-10T19:30:43.490078Z", + "shell.execute_reply.started": "2024-01-10T19:30:38.315899Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Patching torch.distributed state to support multithreading.\n", + "WARNING:root:torch.distributed support on TPU v2 and v3 is experimental and does not support torchrun.\n", + "[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n", + "[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n", + "[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n", + "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n", + "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n", + "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. 
Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n", + "[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n", + "4 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n", + "6 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n", + "7 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n", + "0 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n", + "5 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n", + "3 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n", + "2 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n" + ] + } + ], + "source": [ + "import torch.distributed as dist\n", + "import torch.nn as nn\n", + "from torch.nn.parallel import DistributedDataParallel as DDP\n", + "import torch.optim as optim\n", + "\n", + "import torch_xla.distributed.xla_backend # Registers `xla://` init_method\n", + "import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n", + "\n", + "def toy_model(index, lock):\n", + " device = xm.xla_device()\n", + " dist.init_process_group('xla', init_method='xla://')\n", + "\n", + " # Initialize a basic toy model\n", + " torch.manual_seed(42)\n", + " model = nn.Linear(128, 10).to(device)\n", + "\n", + " # Optional for TPU v4 and GPU\n", + " xm.broadcast_master_param(model)\n", + "\n", + " model = DDP(model)\n", + "\n", + " loss_fn = nn.MSELoss()\n", + " optimizer = optim.SGD(model.parameters(), lr=.001)\n", + "\n", + " for i in range(10):\n", + " # Generate random inputs and outputs on the XLA device\n", + " data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)\n", + "\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = loss_fn(output, target)\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " \n", + " # Run the pending graph\n", + " xm.mark_step()\n", + "\n", + " with lock:\n", + " # Print mean parameters so we can confirm they're the same across replicas\n", + " print(index, [p.mean() for p in model.parameters()])\n", + "\n", + "xmp.spawn(toy_model, args=(lock,), start_method='fork')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For a more in-depth look at PyTorch/XLA, see our [API guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md)." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/source/learn/pjrt.md b/docs/source/learn/pjrt.md index 6fc84bf9de3..97f358f29e5 100644 --- a/docs/source/learn/pjrt.md +++ b/docs/source/learn/pjrt.md @@ -82,7 +82,7 @@ def _mp_fn(index): + # Optional for TPU v4 and GPU + xm.broadcast_master_param(model) - model = DDP(model, gradient_as_bucket_view=True) + model = DDP(model) loss_fn = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=.001) diff --git a/docs/source/perf/ddp.md b/docs/source/perf/ddp.md index a84946cc728..3319a47c2f2 100644 --- a/docs/source/perf/ddp.md +++ b/docs/source/perf/ddp.md @@ -40,10 +40,10 @@ device](../API_GUIDE.md#running-on-a-single-xla-device). world_size = xr.world_size() ``` -4. Pass `gradient_as_bucket_view=True` to the DDP wrapper. +4. Wrap the model with DDP. ``` python - ddp_model = DDP(model, gradient_as_bucket_view=True) + ddp_model = DDP(model) ``` 5. Finally launch your model with xla specific launcher. @@ -107,8 +107,7 @@ def demo_basic(rank): # create model and move it to XLA device device = xm.xla_device() model = ToyModel().to(device) - # currently, graident_as_bucket_view is needed to make DDP work for xla - ddp_model = DDP(model, gradient_as_bucket_view=True) + ddp_model = DDP(model) loss_fn = nn.MSELoss() optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) @@ -246,6 +245,6 @@ the native xla data parallel approach, here is the [tutorial](../API_GUIDE.md#running-on-multiple-xla-devices-with-multi-processing). Here are some of the known issues that are under investigation: \* -`gradient_as_bucket_view=True` needs to be enforced. \* There are some +`gradient_as_bucket_view=False` needs to be enforced. \* There are some issues while being used with `torch.utils.data.DataLoader`. `test_train_mp_mnist.py` with real data crashes before exiting. diff --git a/examples/data_parallel/train_resnet_ddp.py b/examples/data_parallel/train_resnet_ddp.py index d5f8da4a9a7..327b3f8cbbc 100644 --- a/examples/data_parallel/train_resnet_ddp.py +++ b/examples/data_parallel/train_resnet_ddp.py @@ -18,7 +18,7 @@ def __init__(self): dist.init_process_group('xla', init_method='xla://') super().__init__() self.model = DDP( - self.model, gradient_as_bucket_view=True, broadcast_buffers=False) + self.model, broadcast_buffers=False) self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4) diff --git a/test/distributed_util.py b/test/distributed_util.py index 2023e72e56b..4d428e202da 100644 --- a/test/distributed_util.py +++ b/test/distributed_util.py @@ -111,18 +111,15 @@ def ddp_correctness(init_method: str = 'env://', steps = 5 # To save test time. cpu_model = LargeNet() - # TODO(@alanwaketan): Investigate whether we can omit the gradient_as_bucket_view option. + # TODO: There're issues in the captured graph when gradient_as_bucket_view is True # bucket_cap_mb is set to 1 mb such that we can still have multiple all_reduces while avoiding # using models that are too larger (25 mb). # To be noted, DDP currently uses one bucket for the first iteration. See pytorch#73732. 
- ddp_model = DDP( - copy.deepcopy(cpu_model).to(device), - gradient_as_bucket_view=True, - bucket_cap_mb=1) + ddp_model = DDP(copy.deepcopy(cpu_model).to(device), bucket_cap_mb=1) # ddp_model.register_comm_hook(state=None, hook=comp_hook) - cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-4) - ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-4) + cpu_optimizer = optim.SGD(cpu_model.parameters(), lr=1e-1) + ddp_optimizer = optim.SGD(ddp_model.parameters(), lr=1e-1) loss_fn = nn.MSELoss() local_batch_size = 2 diff --git a/test/test_inplace_update.py b/test/test_inplace_update.py new file mode 100644 index 00000000000..d4fa3a44c7a --- /dev/null +++ b/test/test_inplace_update.py @@ -0,0 +1,76 @@ +import io +import sys +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +from test_utils import temporary_env + + +class InplaceUpdateTest(unittest.TestCase): + + def test_aten_op_after_full_update(self): + device = xm.xla_device() + t = torch.ones(2, 1, device=device) + w = torch.ones(1, 2, device=device) + t.zero_() + y = torch.matmul(t, w) + expected = torch.zeros(2, 2, device=device) + xm.mark_step() + self.assertTrue(torch.all(torch.eq(y, expected))) + + def test_aten_op_after_partial_update(self): + device = xm.xla_device() + t = torch.ones(2, 1, device=device) + w = torch.ones(1, 2, device=device) + t[0][0] = 0 + y = torch.matmul(t, w) + expected = torch.tensor([[0, 0], [1, 1]], device=device) + xm.mark_step() + self.assertTrue(torch.all(torch.eq(y, expected))) + + def test_non_aten_op_after_full_update(self): + device = xm.xla_device() + t = torch.ones(2, 1, device=device) + w = torch.ones(1, 2, device=device) + t.zero_() + y = torch_xla._XLAC._xla_dot_general(t, w, (([1], [0]), ())) + expected = torch.zeros(2, 2, device=device) + xm.mark_step() + self.assertTrue(torch.all(torch.eq(y, expected))) + + def test_non_aten_op_after_partial_update(self): + device = xm.xla_device() + t = torch.ones(2, 1, device=device) + w = torch.ones(1, 2, device=device) + t[0][0] = 0 + y = torch_xla._XLAC._xla_dot_general(t, w, (([1], [0]), ())) + expected = torch.tensor([[0, 0], [1, 1]], device=device) + xm.mark_step() + self.assertTrue(torch.all(torch.eq(y, expected))) + + def test_xm_save(self): + with temporary_env( + XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"): + xla_device = xm.xla_device() + t1 = torch.tensor([1], device=xla_device) + t2 = t1.detach() + xm.mark_step() + + t2.add_(t2) + xm.mark_step() + + # mark_step() causes t1 and t2 to be out of sync on the XLA side. 
+ + fobj = io.BytesIO() + xm.save({'t1': t1}, fobj) + fobj.seek(0) + saved = torch.load(fobj) + + self.assertEqual(t1.item(), saved['t1'].item()) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_train_mp_imagenet.py b/test/test_train_mp_imagenet.py index fb7db89693a..cc761c875e7 100644 --- a/test/test_train_mp_imagenet.py +++ b/test/test_train_mp_imagenet.py @@ -258,7 +258,7 @@ def train_imagenet(): xm.broadcast_master_param(model) if FLAGS.ddp: - model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False) + model = DDP(model, broadcast_buffers=False) writer = None if xm.is_master_ordinal(): diff --git a/test/test_train_mp_mnist.py b/test/test_train_mp_mnist.py index 9e470719f27..315f4b200f6 100644 --- a/test/test_train_mp_mnist.py +++ b/test/test_train_mp_mnist.py @@ -138,7 +138,7 @@ def train_mnist(flags, **kwargs): xm.broadcast_master_param(model) if flags.ddp: - model = DDP(model, gradient_as_bucket_view=True) + model = DDP(model) writer = None if xm.is_master_ordinal(): writer = test_utils.get_summary_writer(flags.logdir) diff --git a/test/test_utils.py b/test/test_utils.py index 4aefdce6805..ad00a1def62 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,5 @@ import collections +from contextlib import contextmanager import itertools import math import os @@ -390,3 +391,32 @@ def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5): results = xu.as_list(fn(*tensors)) xla_results = xu.as_list(fn(*xla_tensors)) self.compareResults(results, xla_results, rel_err=rel_err, abs_err=abs_err) + + +@contextmanager +def temporary_env(**kwargs): + """ + Temporarily set environment variables within the context. + + Args: + **kwargs: Key-value pairs representing environment variables to set. + For example: temporary_env(PATH='/new/path', DEBUG='1') + """ + original_env = {} + + # Store original values and set new ones + for key, value in kwargs.items(): + original_env[key] = os.environ.get(key, None) + os.environ[key] = value + + try: + yield + finally: + # Restore original environment variables + for key, old_value in original_env.items(): + if old_value is None: + # The variable was not originally set + del os.environ[key] + else: + # Restore the original value + os.environ[key] = old_value diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp index 8c94b2d9c5a..6344aa5d1e5 100644 --- a/torch_xla/csrc/aten_xla_bridge.cpp +++ b/torch_xla/csrc/aten_xla_bridge.cpp @@ -80,6 +80,11 @@ XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) { } // namespace XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) { + if (tensor.defined() && + at::functionalization::impl::isFunctionalTensor(tensor)) { + // To make sure we have the most updated version of tensor. + at::functionalization::impl::sync(tensor); + } XLATensorImpl* impl = GetXlaTensorImpl(tensor); if (impl == nullptr) { return XLATensorPtr();
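
For context, a minimal sketch of how the `temporary_env` helper added to `test/test_utils.py` could be used on its own, outside `test_xm_save`. This assumes the snippet runs from the `test/` directory so that `test_utils` is importable, and that the chosen environment variables are still consulted at the point where the block executes; the variable values and tensor shapes are illustrative only.

``` python
# Sketch only: exercising the temporary_env context manager from test/test_utils.py.
# Assumes this file lives next to test_utils.py so the import resolves.
import torch
import torch_xla.core.xla_model as xm
from test_utils import temporary_env

# Inside the block both variables are set to "0"; temporary_env records their
# previous values (or absence) so it can restore them afterwards.
with temporary_env(
    XLA_DISABLE_FUNCTIONALIZATION="0", XLA_ENABLE_PARAM_ALIASING="0"):
  t = torch.ones(2, 2, device=xm.xla_device())
  t.add_(1)  # in-place update, mirroring the aliasing/functionalization tests above
  xm.mark_step()
  print(t.cpu())

# On exit, variables that were previously unset are deleted and the rest are
# restored to their original values, even if the body raised an exception.
```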