diff --git a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb index 48c7d06b4ee..5a517565847 100644 --- a/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb +++ b/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb @@ -1 +1 @@ -{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.8.16","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"source":"\"Kaggle\"","metadata":{},"cell_type":"markdown"},{"cell_type":"markdown","source":"# Distributed PyTorch/XLA Basics with PJRT (Beta)\n\nBeginning with PyTorch/XLA 2.0, Kaggle supports the new PJRT preview runtime on TPU VMs! For more information about PJRT, see the [PyTorch/XLA GitHub repository](https://github.com/pytorch/xla/blob/master/docs/pjrt.md).\n\nPyTorch/XLA is a package that lets PyTorch run on TPU devices. Kaggle provides a free v3-8 TPU VM. v3-8 TPUs have 8 logical devices: 4 TPU chips, each having 2 cores. This notebook shows how to run simple distributed operations on a TPU using the PJRT runtime. For more information about the Cloud TPU architecture, [see the official documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).\n\nAt the time of writing Kaggle Notebooks on TPU VM are preinstalled with Python 3.8 and PT/XLA 2.0. See below for the exact versions.","metadata":{}},{"cell_type":"code","source":"!python --version","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:35:49.092292Z","iopub.execute_input":"2023-04-06T21:35:49.092669Z","iopub.status.idle":"2023-04-06T21:35:49.84455Z","shell.execute_reply.started":"2023-04-06T21:35:49.092628Z","shell.execute_reply":"2023-04-06T21:35:49.843514Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stdout","text":"Python 3.8.16\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\ntorch.__version__","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:35:49.846433Z","iopub.execute_input":"2023-04-06T21:35:49.846774Z","iopub.status.idle":"2023-04-06T21:35:51.556215Z","shell.execute_reply.started":"2023-04-06T21:35:49.846732Z","shell.execute_reply":"2023-04-06T21:35:51.555445Z"},"trusted":true},"execution_count":2,"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"'2.0.0+cu117'"},"metadata":{}}]},{"cell_type":"code","source":"import torch_xla\ntorch_xla.__version__","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:35:51.557251Z","iopub.execute_input":"2023-04-06T21:35:51.557641Z","iopub.status.idle":"2023-04-06T21:35:51.833919Z","shell.execute_reply.started":"2023-04-06T21:35:51.557603Z","shell.execute_reply":"2023-04-06T21:35:51.832939Z"},"trusted":true},"execution_count":3,"outputs":[{"execution_count":3,"output_type":"execute_result","data":{"text/plain":"'2.0'"},"metadata":{}}]},{"cell_type":"markdown","source":"Unlike JAX or TensorFlow, the convention in PyTorch is to start a separate child process per device to minimize the impact of Python's [Global Interpreter Lock](https://en.wikipedia.org/wiki/Global_interpreter_lock). In eager PyTorch, this means spawning one child process per GPU. For more information, see [PyTorch's distributed training documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#comparison-between-dataparallel-and-distributeddataparallel).\n\nDue to architectural constraints, it is not possible for more than one process to access a TPU chip simultaneously. Because TPU v3 has two TensorCores cores per TPU chip, that means that each process must drive at least two TPU cores. By default, PyTorch/XLA will spawn 4 processes in total (one per chip), each having two threads (one per TensorCore). This is handled transparently by `xmp.spawn`, which mirrors `mp.spawn`. However, it is important to keep in mind that _all distributed workloads on a TPU v2 or v3 are multithreaded_. The function you pass to `spawn` should be thread-safe.\n\nTPU v4 has a different architecture, where each TPU chip is represented to PyTorch as a single device, so we spawn one process per device as expected.\n\nSee the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) for an in-depth look at TPU architecture.","metadata":{}},{"cell_type":"code","source":"!printenv","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Temporary hack: remove some TPU environment variables to support multiprocessing\n# These will be set later by xmp.spawn.\n\nimport os\nos.environ.pop('TPU_PROCESS_ADDRESSES')\nos.environ.pop('CLOUD_TPU_TASK_ID')","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:35:54.308602Z","iopub.execute_input":"2023-04-06T21:35:54.309589Z","iopub.status.idle":"2023-04-06T21:35:54.316139Z","shell.execute_reply.started":"2023-04-06T21:35:54.30955Z","shell.execute_reply":"2023-04-06T21:35:54.315338Z"},"trusted":true},"execution_count":4,"outputs":[{"execution_count":4,"output_type":"execute_result","data":{"text/plain":"'0'"},"metadata":{}}]},{"cell_type":"code","source":"import torch_xla.core.xla_model as xm\nimport torch_xla.distributed.xla_multiprocessing as xmp","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:35:56.576757Z","iopub.execute_input":"2023-04-06T21:35:56.577723Z","iopub.status.idle":"2023-04-06T21:35:56.686272Z","shell.execute_reply.started":"2023-04-06T21:35:56.577688Z","shell.execute_reply":"2023-04-06T21:35:56.685431Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"markdown","source":"To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`.","metadata":{}},{"cell_type":"code","source":"import multiprocessing as mp\nlock = mp.Manager().Lock()\n\ndef print_device(i, lock):\n device = xm.xla_device()\n with lock:\n print('process', i, device)","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:36:20.404838Z","iopub.execute_input":"2023-04-06T21:36:20.405228Z","iopub.status.idle":"2023-04-06T21:36:20.432795Z","shell.execute_reply.started":"2023-04-06T21:36:20.405198Z","shell.execute_reply":"2023-04-06T21:36:20.431372Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"markdown","source":"To run a function on each TPU device, pass it to `xmp.spawn`. We'll use an `mp.Lock` to prevent `print` statements from overlapping between processes. This make the output clearer, but it is optional.\n\nNote: in interactive notebooks, you must use `start_method='fork'`.","metadata":{}},{"cell_type":"code","source":"xmp.spawn(print_device, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:36:22.001584Z","iopub.execute_input":"2023-04-06T21:36:22.001973Z","iopub.status.idle":"2023-04-06T21:36:53.002622Z","shell.execute_reply.started":"2023-04-06T21:36:22.001935Z","shell.execute_reply":"2023-04-06T21:36:53.001506Z"},"trusted":true},"execution_count":8,"outputs":[{"name":"stderr","text":"E0406 21:36:48.135562430 24639 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:36:48.135542832+00:00\"}\nE0406 21:36:48.140832291 24683 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:36:48.140810769+00:00\"}\nE0406 21:36:48.140935067 24704 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:36:48.140917739+00:00\"}\nE0406 21:36:48.142011794 24681 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:36:48.141988566+00:00\"}\n","output_type":"stream"},{"name":"stdout","text":"process 4 xla:0\nprocess 5 xla:1\nprocess 6 xla:0\nprocess 7 xla:1\nprocess 2 xla:0\nprocess 3 xla:1\nprocess 1 xla:1\nprocess 0 xla:0\n","output_type":"stream"}]},{"cell_type":"markdown","source":"Note: ignore the errors from `oauth2_credentials.cc`. These will be fixed in a future release.","metadata":{}},{"cell_type":"markdown","source":"To run `torch` operations on a TPU, pass the corresponding XLA device in as the `device` parameter. When you pass in an XLA device, the operation is added to a graph, which is executed lazily as needed. To force all devices to evaluate the current graph, call `xm.mark_step()`.","metadata":{}},{"cell_type":"code","source":"def add_ones(i, lock):\n x = torch.ones((3, 3), device=xm.xla_device())\n y = x + x\n \n # Run graph to compute `y` before printing\n xm.mark_step()\n \n with lock:\n print(i, y)\n\nxmp.spawn(add_ones, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:36:53.004808Z","iopub.execute_input":"2023-04-06T21:36:53.005152Z","iopub.status.idle":"2023-04-06T21:37:22.970766Z","shell.execute_reply.started":"2023-04-06T21:36:53.00512Z","shell.execute_reply":"2023-04-06T21:37:22.969616Z"},"trusted":true},"execution_count":9,"outputs":[{"name":"stderr","text":"E0406 21:37:19.123766014 26177 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:\"2023-04-06T21:37:19.123746276+00:00\", grpc_status:2}\nE0406 21:37:19.123927544 25887 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:19.123912017+00:00\"}\nE0406 21:37:19.126139358 26191 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:19.126123317+00:00\"}\nE0406 21:37:19.225654426 26703 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:19.225637409+00:00\"}\n","output_type":"stream"},{"name":"stdout","text":"5 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n3 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n2 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n7 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n6 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n4 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n1 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n0 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n","output_type":"stream"}]},{"cell_type":"markdown","source":"To communicate tensors between TPU devices, use the collective communication operations in `xla_model`, such as `all_gather`.\n\n","metadata":{}},{"cell_type":"code","source":"def gather_ids(i, lock):\n # Create a tensor on each device with the device ID\n t = torch.tensor([i], device=xm.xla_device())\n with lock:\n print(i, t)\n \n # Collect and concatenate the IDs\n ts = xm.all_gather(t)\n xm.mark_step()\n with lock:\n print(i, ts)\n\nxmp.spawn(gather_ids, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:37:22.972289Z","iopub.execute_input":"2023-04-06T21:37:22.97267Z","iopub.status.idle":"2023-04-06T21:37:52.923194Z","shell.execute_reply.started":"2023-04-06T21:37:22.972625Z","shell.execute_reply":"2023-04-06T21:37:52.921931Z"},"trusted":true},"execution_count":10,"outputs":[{"name":"stderr","text":"E0406 21:37:49.090571956 28094 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:49.090549568+00:00\"}\nE0406 21:37:49.090749719 27805 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:49.090730711+00:00\"}\nE0406 21:37:49.091487416 27801 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:49.091471223+00:00\"}\nE0406 21:37:49.095649359 28100 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:37:49.095618555+00:00\"}\n","output_type":"stream"},{"name":"stdout","text":"0 tensor([0], device='xla:0')\n1 tensor([1], device='xla:1')\n3 tensor([3], device='xla:1')\n2 tensor([2], device='xla:0')\n7 tensor([7], device='xla:1')\n6 tensor([6], device='xla:0')\n5 tensor([5], device='xla:1')\n4 tensor([4], device='xla:0')\n1 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n0 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n2 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n7 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n3 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n6 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n4 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n5 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n","output_type":"stream"}]},{"cell_type":"markdown","source":"PyTorch/XLA 2.0 also ships with experimental support for the `torch.distributed` using the `pjrt://` `init_method`, including `DistributedDataParallel`.\n\nBecause replicas are run multithreaded, the distributed function must be thread-safe. However, the global RNG that `torch` uses for module initialization will give inconsistent results between replicas on TPU v3, since there will be multiple threads concurrently using it. To ensure consistent parameters, we recommend broadcasting model parameters from replica 0 to the other replicas using `pjrt.broadcast_master_param`. In practice, you may also load each replica's parameters from a common checkpoint.","metadata":{}},{"cell_type":"code","source":"import torch.distributed as dist\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nimport torch.optim as optim\n\nimport torch_xla.experimental.pjrt_backend # Registers `pjrt://` init_method\nimport torch_xla.experimental.pjrt as pjrt\n\ndef toy_model(index, lock):\n device = xm.xla_device()\n dist.init_process_group('xla', init_method='pjrt://')\n\n # Initialize a basic toy model\n torch.manual_seed(42)\n model = nn.Linear(128, 10).to(device)\n\n # Optional for TPU v4 and GPU\n pjrt.broadcast_master_param(model)\n\n # `gradient_as_bucket_view=True` required for XLA\n model = DDP(model, gradient_as_bucket_view=True)\n\n loss_fn = nn.MSELoss()\n optimizer = optim.SGD(model.parameters(), lr=.001)\n\n for i in range(10):\n # Generate random inputs and outputs on the XLA device\n data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)\n\n optimizer.zero_grad()\n output = model(data)\n loss = loss_fn(output, target)\n loss.backward()\n\n optimizer.step()\n \n # Run the pending graph\n xm.mark_step()\n\n with lock:\n # Print mean parameters so we can confirm they're the same across replicas\n print(index, [p.mean() for p in model.parameters()])\n\nxmp.spawn(toy_model, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2023-04-06T21:37:52.925732Z","iopub.execute_input":"2023-04-06T21:37:52.926355Z","iopub.status.idle":"2023-04-06T21:38:26.980322Z","shell.execute_reply.started":"2023-04-06T21:37:52.926319Z","shell.execute_reply":"2023-04-06T21:38:26.979135Z"},"trusted":true},"execution_count":11,"outputs":[{"name":"stderr","text":"E0406 21:38:19.064130031 30133 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:\"2023-04-06T21:38:19.064111082+00:00\", grpc_status:2}\nE0406 21:38:19.064175505 30135 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:\"2023-04-06T21:38:19.064159685+00:00\", grpc_status:2}\nE0406 21:38:19.064493148 30134 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:38:19.064475235+00:00\"}\nE0406 21:38:19.065542428 30136 oauth2_credentials.cc:236] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:\"2023-04-06T21:38:19.06552556+00:00\"}\n[W socket.cpp:601] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n","output_type":"stream"},{"name":"stdout","text":"0 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n6 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n4 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n7 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n5 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n1 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n2 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n3 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n","output_type":"stream"}]},{"cell_type":"markdown","source":"For a more in-depth look at PyTorch/XLA, see our [API guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md).","metadata":{}}]} \ No newline at end of file +{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Distributed PyTorch/XLA Basics\n\nBeginning with PyTorch/XLA 2.0, Kaggle supports the new PJRT preview runtime on TPU VMs! For more information about PJRT, see the [PyTorch/XLA GitHub repository](https://github.com/pytorch/xla/blob/master/docs/pjrt.md).\n\nPyTorch/XLA is a package that lets PyTorch run on TPU devices. Kaggle provides a free v3-8 TPU VM. v3-8 TPUs have 8 logical devices: 4 TPU chips, each having 2 cores. This notebook shows how to run simple distributed operations on a TPU using the PJRT runtime. For more information about the Cloud TPU architecture, [see the official documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).\n\nAt the time of writing Kaggle Notebooks on TPU VM are preinstalled with Python 3.10 and PT/XLA 2.1. See below for the exact versions.","metadata":{}},{"cell_type":"code","source":"!python --version","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:01.251567Z","iopub.execute_input":"2024-01-10T19:30:01.251900Z","iopub.status.idle":"2024-01-10T19:30:01.378200Z","shell.execute_reply.started":"2024-01-10T19:30:01.251872Z","shell.execute_reply":"2024-01-10T19:30:01.377121Z"},"trusted":true},"execution_count":1,"outputs":[{"name":"stdout","text":"Python 3.10.13\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\ntorch.__version__","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:01.380122Z","iopub.execute_input":"2024-01-10T19:30:01.380390Z","iopub.status.idle":"2024-01-10T19:30:22.624453Z","shell.execute_reply.started":"2024-01-10T19:30:01.380364Z","shell.execute_reply":"2024-01-10T19:30:22.623753Z"},"trusted":true},"execution_count":2,"outputs":[{"execution_count":2,"output_type":"execute_result","data":{"text/plain":"'2.1.0+cu121'"},"metadata":{}}]},{"cell_type":"code","source":"import torch_xla\ntorch_xla.__version__","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:22.625450Z","iopub.execute_input":"2024-01-10T19:30:22.625785Z","iopub.status.idle":"2024-01-10T19:30:28.439813Z","shell.execute_reply.started":"2024-01-10T19:30:22.625759Z","shell.execute_reply":"2024-01-10T19:30:28.439042Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stderr","text":"/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n","output_type":"stream"},{"execution_count":3,"output_type":"execute_result","data":{"text/plain":"'2.1.0+libtpu'"},"metadata":{}}]},{"cell_type":"markdown","source":"Unlike JAX or TensorFlow, the convention in PyTorch is to start a separate child process per device to minimize the impact of Python's [Global Interpreter Lock](https://en.wikipedia.org/wiki/Global_interpreter_lock). In eager PyTorch, this means spawning one child process per GPU. For more information, see [PyTorch's distributed training documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#comparison-between-dataparallel-and-distributeddataparallel).\n\nDue to architectural constraints, it is not possible for more than one process to access a TPU chip simultaneously. Because TPU v3 has two TensorCores cores per TPU chip, that means that each process must drive at least two TPU cores. By default, PyTorch/XLA will spawn 4 processes in total (one per chip), each having two threads (one per TensorCore). This is handled transparently by `xmp.spawn`, which mirrors `mp.spawn`. However, it is important to keep in mind that _all distributed workloads on a TPU v2 or v3 are multithreaded_. The function you pass to `spawn` should be thread-safe.\n\nTPU v4 has a different architecture, where each TPU chip is represented to PyTorch as a single device, so we spawn one process per device as expected.\n\nSee the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) for an in-depth look at TPU architecture.","metadata":{}},{"cell_type":"code","source":"!printenv","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Temporary hack: remove some TPU environment variables to support multiprocessing\n# These will be set later by xmp.spawn.\n\nimport os\nos.environ.pop('TPU_PROCESS_ADDRESSES')\nos.environ.pop('CLOUD_TPU_TASK_ID')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.574526Z","iopub.execute_input":"2024-01-10T19:30:28.574793Z","iopub.status.idle":"2024-01-10T19:30:28.580178Z","shell.execute_reply.started":"2024-01-10T19:30:28.574764Z","shell.execute_reply":"2024-01-10T19:30:28.579554Z"},"trusted":true},"execution_count":5,"outputs":[{"execution_count":5,"output_type":"execute_result","data":{"text/plain":"'0'"},"metadata":{}}]},{"cell_type":"code","source":"import torch_xla.core.xla_model as xm\nimport torch_xla.distributed.xla_multiprocessing as xmp","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.581117Z","iopub.execute_input":"2024-01-10T19:30:28.581455Z","iopub.status.idle":"2024-01-10T19:30:28.606274Z","shell.execute_reply.started":"2024-01-10T19:30:28.581429Z","shell.execute_reply":"2024-01-10T19:30:28.605671Z"},"trusted":true},"execution_count":6,"outputs":[]},{"cell_type":"markdown","source":"To get the current process/thread's default XLA device, use `xm.xla_device()`. XLA devices are numbered as `xla:i`, where `i` is the index of the device within the current process. Since each process has two devices on a TPU v3, this will be `xla:0` or `xla:1`.","metadata":{}},{"cell_type":"code","source":"import multiprocessing as mp\nlock = mp.Manager().Lock()\n\ndef print_device(i, lock):\n device = xm.xla_device()\n with lock:\n print('process', i, device)","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.607138Z","iopub.execute_input":"2024-01-10T19:30:28.607393Z","iopub.status.idle":"2024-01-10T19:30:28.664032Z","shell.execute_reply.started":"2024-01-10T19:30:28.607368Z","shell.execute_reply":"2024-01-10T19:30:28.662583Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"markdown","source":"To run a function on each TPU device, pass it to `xmp.spawn`. We'll use an `mp.Lock` to prevent `print` statements from overlapping between processes. This make the output clearer, but it is optional.\n\nNote: in interactive notebooks, you must use `start_method='fork'`.","metadata":{}},{"cell_type":"code","source":"xmp.spawn(print_device, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:28.666339Z","iopub.execute_input":"2024-01-10T19:30:28.666657Z","iopub.status.idle":"2024-01-10T19:30:33.218095Z","shell.execute_reply.started":"2024-01-10T19:30:28.666621Z","shell.execute_reply":"2024-01-10T19:30:33.216950Z"},"trusted":true},"execution_count":8,"outputs":[{"name":"stdout","text":"process 4 xla:0\nprocess 0 xla:0\nprocess 1 xla:1\nprocess 6 xla:0\nprocess 7 xla:1\nprocess 5 xla:1\nprocess 2 xla:0\nprocess 3 xla:1\n","output_type":"stream"}]},{"cell_type":"markdown","source":"Note: ignore the errors from `oauth2_credentials.cc`. These will be fixed in a future release.","metadata":{}},{"cell_type":"markdown","source":"To run `torch` operations on a TPU, pass the corresponding XLA device in as the `device` parameter. When you pass in an XLA device, the operation is added to a graph, which is executed lazily as needed. To force all devices to evaluate the current graph, call `xm.mark_step()`.","metadata":{}},{"cell_type":"code","source":"def add_ones(i, lock):\n x = torch.ones((3, 3), device=xm.xla_device())\n y = x + x\n \n # Run graph to compute `y` before printing\n xm.mark_step()\n \n with lock:\n print(i, y)\n\nxmp.spawn(add_ones, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:33.219569Z","iopub.execute_input":"2024-01-10T19:30:33.219878Z","iopub.status.idle":"2024-01-10T19:30:35.653084Z","shell.execute_reply.started":"2024-01-10T19:30:33.219847Z","shell.execute_reply":"2024-01-10T19:30:35.651887Z"},"trusted":true},"execution_count":9,"outputs":[{"name":"stdout","text":"6 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n7 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n0 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n2 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n3 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n1 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n4 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:0')\n5 tensor([[2., 2., 2.],\n [2., 2., 2.],\n [2., 2., 2.]], device='xla:1')\n","output_type":"stream"}]},{"cell_type":"markdown","source":"To communicate tensors between TPU devices, use the collective communication operations in `xla_model`, such as `all_gather`.\n\n","metadata":{}},{"cell_type":"code","source":"def gather_ids(i, lock):\n # Create a tensor on each device with the device ID\n t = torch.tensor([i], device=xm.xla_device())\n with lock:\n print(i, t)\n \n # Collect and concatenate the IDs\n ts = xm.all_gather(t)\n xm.mark_step()\n with lock:\n print(i, ts)\n\nxmp.spawn(gather_ids, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:35.656377Z","iopub.execute_input":"2024-01-10T19:30:35.656796Z","iopub.status.idle":"2024-01-10T19:30:38.314318Z","shell.execute_reply.started":"2024-01-10T19:30:35.656763Z","shell.execute_reply":"2024-01-10T19:30:38.313118Z"},"trusted":true},"execution_count":10,"outputs":[{"name":"stdout","text":"7 tensor([7], device='xla:1')\n6 tensor([6], device='xla:0')\n4 tensor([4], device='xla:0')\n5 tensor([5], device='xla:1')\n3 tensor([3], device='xla:1')\n2 tensor([2], device='xla:0')\n1 tensor([1], device='xla:1')\n0 tensor([0], device='xla:0')\n7 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n6 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n5 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n4 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n3 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n2 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n0 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:0')\n1 tensor([0, 1, 2, 3, 4, 5, 6, 7], device='xla:1')\n","output_type":"stream"}]},{"cell_type":"markdown","source":"PyTorch/XLA 2.0 also ships with experimental support for the `torch.distributed` using the `pjrt://` `init_method`, including `DistributedDataParallel`.\n\nBecause replicas are run multithreaded, the distributed function must be thread-safe. However, the global RNG that `torch` uses for module initialization will give inconsistent results between replicas on TPU v3, since there will be multiple threads concurrently using it. To ensure consistent parameters, we recommend broadcasting model parameters from replica 0 to the other replicas using `pjrt.broadcast_master_param`. In practice, you may also load each replica's parameters from a common checkpoint.","metadata":{}},{"cell_type":"code","source":"import torch.distributed as dist\nimport torch.nn as nn\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nimport torch.optim as optim\n\nimport torch_xla.distributed.xla_backend # Registers `xla://` init_method\nimport torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n\ndef toy_model(index, lock):\n device = xm.xla_device()\n dist.init_process_group('xla', init_method='xla://')\n\n # Initialize a basic toy model\n torch.manual_seed(42)\n model = nn.Linear(128, 10).to(device)\n\n # Optional for TPU v4 and GPU\n xm.broadcast_master_param(model)\n\n # `gradient_as_bucket_view=True` required for XLA\n model = DDP(model, gradient_as_bucket_view=True)\n\n loss_fn = nn.MSELoss()\n optimizer = optim.SGD(model.parameters(), lr=.001)\n\n for i in range(10):\n # Generate random inputs and outputs on the XLA device\n data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)\n\n optimizer.zero_grad()\n output = model(data)\n loss = loss_fn(output, target)\n loss.backward()\n\n optimizer.step()\n \n # Run the pending graph\n xm.mark_step()\n\n with lock:\n # Print mean parameters so we can confirm they're the same across replicas\n print(index, [p.mean() for p in model.parameters()])\n\nxmp.spawn(toy_model, args=(lock,), start_method='fork')","metadata":{"execution":{"iopub.status.busy":"2024-01-10T19:30:38.315653Z","iopub.execute_input":"2024-01-10T19:30:38.315927Z","iopub.status.idle":"2024-01-10T19:30:43.491104Z","shell.execute_reply.started":"2024-01-10T19:30:38.315899Z","shell.execute_reply":"2024-01-10T19:30:43.490078Z"},"trusted":true},"execution_count":11,"outputs":[{"name":"stderr","text":"WARNING:root:Patching torch.distributed state to support multithreading.\nWARNING:root:torch.distributed support on TPU v2 and v3 is experimental and does not support torchrun.\n[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W socket.cpp:663] [c10d] The client socket has failed to connect to [localhost]:12355 (errno: 99 - Cannot assign requested address).\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n[W logger.cpp:326] Warning: Time stats are currently only collected for CPU and CUDA devices. Please refer to CpuTimer or CudaTimer for how to register timer for other device type. (function operator())\n","output_type":"stream"},{"name":"stdout","text":"1 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n4 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n6 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n7 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n0 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n5 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n3 [tensor(-0.0005, device='xla:1', grad_fn=), tensor(-0.0019, device='xla:1', grad_fn=)]\n2 [tensor(-0.0005, device='xla:0', grad_fn=), tensor(-0.0019, device='xla:0', grad_fn=)]\n","output_type":"stream"}]},{"cell_type":"markdown","source":"For a more in-depth look at PyTorch/XLA, see our [API guide](https://github.com/pytorch/xla/blob/master/API_GUIDE.md).","metadata":{}}]} \ No newline at end of file