Update XLA support (#2964)
* Fix initialization and microbatching for TPUs

* add version check for PyTorch XLA >= 2.1
bfontain authored Feb 22, 2024
1 parent ccb5e56 commit c0a9697
Showing 3 changed files with 21 additions and 2 deletions.
1 change: 1 addition & 0 deletions composer/devices/device_tpu.py
@@ -26,6 +26,7 @@ class DeviceTPU(Device):
     More details.
     """
 
+    dist_backend = 'xla'
     name = 'tpu'
 
     def __init__(self):
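
A hedged usage sketch (not part of the commit) of how the new attribute is consumed: initialize_dist() in composer/utils/dist.py reads device.dist_backend to pick the torch.distributed backend, so constructing a DeviceTPU now routes initialization through the XLA path added later in this commit. Assumes a TPU VM with torch_xla installed; import paths follow Composer's public API.

from composer.devices import DeviceTPU
from composer.utils import dist

device = DeviceTPU()
assert device.dist_backend == 'xla'  # read by initialize_dist() to choose the backend
assert device.name == 'tpu'

# With dist_backend == 'xla', initialize_dist() calls
# torch.distributed.init_process_group('xla', init_method='xla://')
# (see the composer/utils/dist.py hunk later in this commit).
dist.initialize_dist(device)
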
5 changes: 5 additions & 0 deletions composer/trainer/trainer.py
@@ -2567,6 +2567,11 @@ def _train_microbatch(self, use_grad_scaling: bool, current_batch_size: int,
             microbatch_loss.mul_(microbatch_num_samples / current_batch_size)
             microbatch_loss.backward(create_graph=self._backwards_create_graph)
 
+            if self.state.device.dist_backend == 'xla':
+                # For XLA devices, the program between any pair of mark_step() calls is compiled. Without
+                # this, the microbatching loop is unrolled, drastically increasing compile time.
+                xm.mark_step()
+
             self.engine.run_event(Event.AFTER_BACKWARD)
 
             # Use microbatch outputs to update training metrics
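
A standalone sketch of the same idea outside Composer: calling xm.mark_step() after each micro-batch cuts the lazily traced graph, so XLA compiles one micro-batch step instead of an unrolled program spanning every micro-batch. Assumes torch_xla >= 2.1 and an available XLA (TPU) device; the toy model and data are illustrative only.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(16, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch = torch.randn(8, 16, device=device)
targets = torch.randn(8, 2, device=device)
num_microbatches = 4

optimizer.zero_grad()
for micro_x, micro_y in zip(batch.chunk(num_microbatches), targets.chunk(num_microbatches)):
    loss = torch.nn.functional.mse_loss(model(micro_x), micro_y)
    (loss / num_microbatches).backward()
    xm.mark_step()  # cut the traced graph here: only one micro-batch step gets compiled

optimizer.step()
xm.mark_step()  # materialize the parameter update
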
17 changes: 15 additions & 2 deletions composer/utils/dist.py
@@ -37,15 +37,20 @@
 import logging
 import os
 import pickle
+import sys
 import time
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast
 
 import torch
 import torch.distributed as dist
 import torch.utils.data
+from packaging import version
 
-from composer.utils.device import get_device, is_hpu_installed
+from composer.utils.device import get_device, is_hpu_installed, is_tpu_installed
+
+if is_tpu_installed():
+    import torch_xla
 
 if TYPE_CHECKING:
     from composer.devices import Device
@@ -534,7 +539,15 @@ def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
 
     dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())
 
-    if dist_env_vars_match_defaults:
+    if device_obj.dist_backend == 'xla':
+        if 'torch_xla' not in sys.modules:
+            raise RuntimeError('PyTorch XLA package not found. In order to use XLA-based devices, '
+                               'PyTorch XLA must be installed.')
+        if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
+            raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
+        # XLA initialization requires the init_method to be set
+        dist.init_process_group(device_obj.dist_backend, init_method='xla://')
+    elif dist_env_vars_match_defaults:
         # Fill in the remaining single-rank variables
         os.environ.update(dist_env_var_defaults)
         dist.init_process_group(device_obj.dist_backend, store=dist.HashStore(), world_size=1, rank=0)
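
For reference, a self-contained sketch of what the new branch boils down to when pulled out of Composer. The helper name init_xla_process_group is hypothetical, not part of the library; it assumes torch_xla >= 2.1, where importing torch_xla is expected to make the 'xla' backend and the 'xla://' init_method available.

import sys

import torch.distributed as dist
from packaging import version

try:
    import torch_xla
except ImportError:
    torch_xla = None


def init_xla_process_group() -> None:
    """Hypothetical standalone equivalent of the new 'xla' branch in initialize_dist()."""
    if torch_xla is None or 'torch_xla' not in sys.modules:
        raise RuntimeError('PyTorch XLA package not found. In order to use XLA-based devices, '
                           'PyTorch XLA must be installed.')
    if version.parse(torch_xla.__version__) < version.parse('2.1.0'):
        raise RuntimeError(f'PyTorch XLA version must be at least 2.1.0, found {torch_xla.__version__}.')
    # 'xla://' lets the XLA runtime supply rank, world size, and rendezvous details,
    # rather than the MASTER_ADDR/MASTER_PORT environment variables other backends use.
    dist.init_process_group('xla', init_method='xla://')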
