diff --git a/index.md b/index.md
index ee8f42ac852..8b47806393e 100644
--- a/index.md
+++ b/index.md
@@ -1,5 +1,5 @@
 ---
 layout: docs_redirect
 title: PyTorch | Redirect
-redirect_url: "/xla/release/r2.4/index.html"
+redirect_url: "/xla/release/r2.5/index.html"
 ---
diff --git a/release/2.5/_images/ddp_md_mnist_with_real_data.png b/release/2.5/_images/ddp_md_mnist_with_real_data.png
new file mode 100644
index 00000000000..f83c5182be6
Binary files /dev/null and b/release/2.5/_images/ddp_md_mnist_with_real_data.png differ
diff --git a/release/2.5/_images/spmd_mode.png b/release/2.5/_images/spmd_mode.png
new file mode 100644
index 00000000000..dd9b5cc69cc
Binary files /dev/null and b/release/2.5/_images/spmd_mode.png differ
diff --git a/release/2.5/_images/torchbench_pjrt_vs_xrt.svg b/release/2.5/_images/torchbench_pjrt_vs_xrt.svg
new file mode 100644
index 00000000000..effe9b72be8
--- /dev/null
+++ b/release/2.5/_images/torchbench_pjrt_vs_xrt.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/release/2.5/_images/torchbench_tfrt_vs_se.svg b/release/2.5/_images/torchbench_tfrt_vs_se.svg
new file mode 100644
index 00000000000..161f0433b0a
--- /dev/null
+++ b/release/2.5/_images/torchbench_tfrt_vs_se.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/release/2.5/_modules/index.html b/release/2.5/_modules/index.html
new file mode 100644
index 00000000000..503d30fd780
--- /dev/null
+++ b/release/2.5/_modules/index.html
@@ -0,0 +1,712 @@
+import contextlib
+import io
+import itertools
+import logging
+import sys
+import re
+import threading
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Tuple, TypedDict, Union
+import torch
+import torch.distributed._functional_collectives
+from torch.library import Library
+import torch.nn.functional as F
+import torch.optim as optim
+import torch_xla
+from torch_xla import runtime
+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.debug.metrics_saver as ms
+import torch_xla.utils.utils as xu
+import torch_xla.utils.closures as xc
+from torch_xla.distributed.spmd.xla_sharding import ShardingSpec
+import os
+from torch_xla.experimental.deprecation import deprecated
+import torch_xla._internal.utils as _utils
+
+_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
+
+REDUCE_SUM = 'sum'
+REDUCE_MUL = 'mul'
+REDUCE_AND = 'and'
+REDUCE_OR = 'or'
+REDUCE_MIN = 'min'
+REDUCE_MAX = 'max'
+
+_DEVICE_CONTEXTS = dict()
+_DEVICE_CONTEXTS_LOCK = threading.Lock()
+
+XLA_LIB = Library("xla", "DEF")
+
+from . import xla_model as this_module
+
+xrt_world_size = deprecated(this_module, torch_xla.runtime.world_size,
+ 'xrt_world_size() will be removed in release 2.7.')
+get_ordinal = deprecated(
+ this_module, torch_xla.runtime.global_ordinal,
+ 'xla_model.get_ordinal() will be removed in release 2.7.')
+parse_xla_device = deprecated(
+ this_module, _utils.parse_xla_device,
+ 'xla_model.parse_xla_device() will be removed in release 2.7.')
+
+
+class DeviceContext(object):
+
+ def __init__(self, device: Union[str, torch.device]):
+ self.device = device
+
+
+def _get_device_context(
+ device: Optional[Union[str, torch.device]] = None) -> DeviceContext:
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ else:
+ device = str(device)
+ with _DEVICE_CONTEXTS_LOCK:
+ devctx = _DEVICE_CONTEXTS.get(device, None)
+ if devctx is None:
+ devctx = DeviceContext(device)
+ _DEVICE_CONTEXTS[device] = devctx
+ return devctx
+
+
+def is_xla_tensor(tensor: torch.Tensor) -> bool:
+ return tensor.device.type == 'xla'
+
+
+def get_xla_supported_devices(devkind: Optional[str] = None,
+ max_devices: Optional[int] = None) -> List[str]:
+ """Returns a list of supported devices of a given kind.
+
+ Args:
+ devkind (string..., optional): If specified, a device type such as `TPU`,
+ `CUDA`, `CPU`, or name of custom PJRT device.
+ max_devices (int, optional): The maximum number of devices to be returned of
+ that kind.
+
+ Returns:
+ The list of device strings such as ['xla:0', 'xla:1', ...]
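+
+  Example (illustrative sketch; the output depends on the attached accelerators):
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> xm.get_xla_supported_devices()
+    ['xla:0', 'xla:1', 'xla:2', 'xla:3']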
+ """
+ # TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
+ # multiple device types.
+ if not devkind:
+ devices = torch_xla._XLAC._xla_get_devices()
+ return [
+ f'xla:{i}'
+ for i, _ in enumerate(devices[:max_devices] if max_devices else devices)
+ ]
+ else:
+ warnings.warn("`devkind` argument is deprecated and will be removed in a "
+ "future release.")
+
+ xla_devices = _DEVICES.value
+ kind_devices = []
+ for i, device in enumerate(xla_devices):
+ if re.match(devkind + r':\d+$', device):
+ kind_devices.append('xla:{}'.format(i))
+ if kind_devices:
+ return kind_devices[:max_devices] if max_devices else kind_devices
+
+
+def get_local_ordinal() -> int:
+ """Retrieves the replication local ordinal of the current thread.
+
+ The local ordinals range from 0 to the number of local devices minus 1.
+
+ Returns:
+ The replication local ordinal of the current thread.
+ """
+ return runtime.local_ordinal()
+
+
+def is_master_ordinal(local: bool = True) -> bool:
+ """Checks whether the current process is the master ordinal (0).
+
+ Args:
+ local (bool): Whether the local or global master ordinal should be checked.
+ In case of multi-host replication, there is only one global master ordinal
+ (host 0, device 0), while there are NUM_HOSTS local master ordinals.
+ Default: True
+
+ Returns:
+ A boolean indicating whether the current process is the master ordinal.
+ """
+ ordinal = get_local_ordinal() if local else runtime.global_ordinal()
+ return ordinal == 0
+
+
+def master_print(*args: Tuple[Any, ...],
+ fd: TextIO = sys.stdout,
+ local: bool = False,
+ flush: bool = False):
+ if is_master_ordinal(local=local):
+ print(*args, file=fd, flush=flush)
+
+
+def xla_device(n: Optional[int] = None,
+ devkind: Optional[str] = None) -> torch.device:
+ """Returns a given instance of an XLA device.
+
+ Args:
+ n (int, optional): The specific instance (ordinal) to be returned. If
+ specified, the specific XLA device instance will be returned. Otherwise
+ the first device of `devkind` will be returned.
+ devkind (string..., optional): If specified, device type such as `TPU`,
+ `CUDA`, `CPU`, or custom PJRT device. Deprecated.
+
+ Returns:
+ A `torch.device` with the requested instance.
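+
+  Example (illustrative sketch; the actual device depends on the runtime):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> device = xm.xla_device()
+    >>> t = torch.zeros(2, 2, device=device)  # tensor placed on the XLA device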
+ """
+ # When SPMD is enabled, we always return `xla:0` to the user, and
+ # under the hood we use virtual device logic for every xla tensor
+ if xu.check_env_flag('XLA_USE_SPMD'):
+ device = 'xla:0'
+ torch_xla._XLAC._xla_set_default_device(device)
+ return torch.device(device)
+
+ return runtime.xla_device(n, devkind)
+
+
+def _xla_real_device(device: torch.device) -> Any:
+ device_str = str(device)
+ m = re.match(r'xla:(\d+)$', device_str)
+ if not m:
+ raise RuntimeError('Invalid device format: {}'.format(device_str))
+ return _DEVICES.value[int(m.group(1))]
+
+
+def xla_real_devices(devices: Optional[List[torch.device]] = None) -> List[str]:
+ """Returns the real devices' name.
+
+ Args:
+ devices: The list of torch devices such as ['xla:0', 'xla:1'].
+
+ Returns:
+    A list of real device names such as ['CUDA:0', 'CUDA:1'].
+ """
+ if not devices:
+ devices = get_xla_supported_devices()
+
+ return [_xla_real_device(device) for device in devices]
+
+
+def xla_device_hw(device: Union[str, torch.device]) -> str:
+ """Returns the hardware type of the given device.
+
+ Args:
+ device (string or torch.device): The xla device that will be mapped to the
+ real device.
+
+ Returns:
+ A string representation of the hardware type of the given device.
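+
+  Example (illustrative sketch; the result depends on the attached hardware):
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> xm.xla_device_hw(xm.xla_device())
+    'TPU'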
+ """
+ real_device = _xla_real_device(device)
+ return real_device.split(':')[0]
+
+
+def xla_replication_devices(
+ local_devices: Optional[List[torch.device]] = None) -> List[str]:
+ real_devices = xla_real_devices(local_devices)
+ device_types = set()
+ for device in real_devices:
+ xdev = _utils.parse_xla_device(device)
+ device_types.add(xdev[0])
+ if len(device_types) != 1:
+ # No replication if the device set spawns multiple device types.
+ raise RuntimeError(
+ 'Cannot replicate across different device types: devices={}/{}'.format(
+ local_devices, real_devices))
+ device_type = device_types.pop()
+ kind_devices = get_xla_supported_devices()
+ if len(kind_devices) != len(local_devices):
+ # Replication can only happen among all devices of one kind.
+ raise RuntimeError(
+ 'Cannot replicate if number of devices ({}) is different from {}'.
+ format(len(local_devices), len(kind_devices)))
+ replication_devices = []
+ for device in torch_xla._XLAC._xla_get_all_devices():
+ # device is like 'CUDA:0'
+ xdev = _utils.parse_xla_device(device)
+ if not xdev:
+ raise RuntimeError('Invalid device format: {}'.format(device))
+ if xdev[0] == device_type:
+ replication_devices.append(device)
+ sorted_by_ordinal = sorted(
+ replication_devices,
+ key=lambda device: _utils.parse_xla_device(device)[1])
+ return sorted_by_ordinal
+
+
+def unlazy(tensors: List[torch.Tensor]):
+ """Blocks the program until `tensors` are materialized.
+
+  This API is for benchmarking; don't use it in real models.
+
+ Args:
+ tensors: List of `torch.Tensor`s to materialize. For each
+ Tensor `t` in the list, `t.device` must be an `xla` device.
+ """
+ torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True)
+
+
+def set_replication(device: torch.device,
+ devices: Optional[List[torch.device]]):
+ device = str(device)
+ devctx = _get_device_context(device=device)
+ devices = [str(x) for x in devices]
+ if devices:
+ # sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
+ replication_devices = xla_replication_devices(devices)
+ torch_xla._XLAC._xla_set_replication_devices(replication_devices)
+ devctx.device_index = devices.index(device)
+ else:
+ torch_xla._XLAC._xla_set_replication_devices([])
+ devctx.device_index = 0
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
+ torch_xla._XLAC._xla_set_default_device(device)
+
+
+class RateTracker(object):
+
+ def __init__(self, smooth_factor: Optional[float] = None):
+ self._smooth_factor = xu.getenv_as(
+ 'RATE_TRACKER_SMOOTHING', float,
+ 0.4) if smooth_factor is None else smooth_factor
+ self._start_time = time.time()
+ self._partial_time = self._start_time
+ self._partial_count = 0.0
+ self._partial_rate = None
+ self._count = 0.0
+
+ def _update(self, now: float, rate: float):
+ self._partial_count += self._count
+ self._count = 0.0
+ self._partial_time = now
+ self._partial_rate = rate
+
+ def add(self, count: float):
+ self._count += count
+
+ def _smooth(self, current_rate: float) -> float:
+ if self._partial_rate is None:
+ smoothed_rate = current_rate
+ else:
+ smoothed_rate = ((1 - self._smooth_factor) * current_rate +
+ self._smooth_factor * self._partial_rate)
+ return smoothed_rate
+
+ def rate(self):
+ now = time.time()
+ delta = now - self._partial_time
+ report_rate = 0.0
+ if delta > 0:
+ report_rate = self._smooth(self._count / delta)
+ self._update(now, report_rate)
+ return report_rate
+
+ def global_rate(self):
+ delta = time.time() - self._start_time
+ count = self._partial_count + self._count
+ return count / delta if delta > 0 else 0.0
+
+
+class ToXlaTensorArena(object):
+
+ def __init__(self, convert_fn: Callable[[List[torch.Tensor]],
+ List[torch.Tensor]],
+ select_fn: Callable[[torch.Tensor], bool]):
+ self._convert_fn = convert_fn
+ self._select_fn = select_fn
+ self._tensors = []
+
+ def _add(self, tensor: torch.Tensor):
+ self._tensors.append(tensor)
+
+ def _convert(self):
+ self._index = 0
+ if self._tensors:
+ self._converted_tensors = self._convert_fn(self._tensors)
+ else:
+ self._converted_tensors = []
+
+ def _get_converted_tensor(self) -> torch.Tensor:
+ assert self._index < len(self._converted_tensors)
+ new_tensor = self._converted_tensors[self._index]
+ self._index += 1
+ return new_tensor
+
+ def _collect_tensors(self, inputs: Any):
+
+ def collect_fn(value: Any):
+ self._add(value)
+
+ xu.for_each_instance(inputs, lambda x: self._select_fn(x), collect_fn)
+
+ def _replace_tensors(self, inputs: Any):
+
+ def convert_fn(value: Any):
+ return self._get_converted_tensor()
+
+ return xu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x),
+ convert_fn)
+
+ def transform(self, inputs: Any):
+ self._tensors = []
+ self._collect_tensors(inputs)
+ self._convert()
+ return self._replace_tensors(inputs)
+
+
+def check_view_sharing(obj):
+ tensors = set()
+ aliases = dict()
+
+ def tensor_info(t: torch.Tensor) -> str:
+ return '{}{}'.format(t.dtype, list(t.size()))
+
+ def tensor_id(t: torch.Tensor) -> Tuple[int, str]:
+ if is_xla_tensor(t):
+ return torch_xla._XLAC._xla_get_tensor_id(t), 'xla'
+ return id(t), 'torch'
+
+ def alias_id(t: torch.Tensor) -> Tuple[int, str]:
+ if is_xla_tensor(t):
+ aid = torch_xla._XLAC._xla_get_tensor_view_alias_id(t)
+ return None if aid == 0 else aid, 'xla'
+ return t.storage().data_ptr(), 'torch'
+
+ def check_object(obj):
+ tid = tensor_id(obj)
+ if tid not in tensors:
+ tensors.add(tid)
+ aid = alias_id(obj)
+ if aid[0] is not None:
+ if aid in aliases:
+ oobj = aliases[aid]
+ raise RuntimeError(
+ 'Tensor ID {} ({}) is sharing a view with tensor ID {} ({})'.
+ format(tid, tensor_info(obj), tensor_id(oobj), tensor_info(oobj)))
+ aliases[aid] = obj
+
+ xu.for_each_instance(obj, lambda x: type(x) == torch.Tensor, check_object)
+
+
+def _fetch_gradients(optimizer: optim.Optimizer) -> List[torch.Tensor]:
+ gradients = []
+ for param_group in optimizer.__getstate__()['param_groups']:
+ for group, params in param_group.items():
+ if group == 'params':
+ for p in params:
+ if isinstance(p, torch.Tensor) and p.grad is not None:
+ gradients.append(p.grad.data)
+ return gradients
+
+
+def _get_all_reduce_token() -> Tuple[Any, DeviceContext]:
+ devctx = _get_device_context()
+ token = torch_xla._XLAC._get_all_reduce_token(devctx.device)
+ return token, devctx
+
+
+def all_reduce(
+ reduce_type: str,
+ inputs: Union[torch.Tensor, List[torch.Tensor]],
+ scale: float = 1.0,
+ groups: Optional[List[List[int]]] = None,
+ pin_layout: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """Performs an inplace reduce operation on the input tensor(s).
+
+ Args:
+ reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
+ ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
+ ``xm.REDUCE_MAX``.
+ inputs: Either a single `torch.Tensor` or a list of `torch.Tensor` to
+ perform the all reduce op to.
+ scale (float): A default scaling value to be applied after the reduce.
+ Default: 1.0
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    pin_layout (bool, optional): Whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it might
+      cause some XLA compilations to fail. Unpin the layout when you see an error
+      message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ If a single `torch.Tensor` is passed, the return value is a `torch.Tensor`
+ holding the reduced value (across the replicas). If a list/tuple is passed,
+ this function performs an inplace all-reduce op on the input tensors, and
+ returns the list/tuple itself.
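+
+  Example (illustrative sketch; assumes a multi-process/replicated XLA context):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> t = torch.ones(2, 2, device=xm.xla_device())
+    >>> xm.all_reduce(xm.REDUCE_SUM, [t])  # in-place sum of `t` across replicas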
+ """
+ groups = groups or []
+
+ # No-op if there is only one device
+ if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
+ bool, False):
+ if isinstance(inputs, torch.Tensor):
+ return inputs.clone()
+ else:
+ return inputs
+
+ if isinstance(inputs, torch.Tensor):
+ result = None
+ if scale == 1.0 and groups == [] and pin_layout:
+ # TODO(alanwaketan): Support groups.
+ # Only c10d_functional version cc ops are traceable by Dynamo.
+ result = torch.ops._c10d_functional.all_reduce(inputs, reduce_type, "")
+ else:
+ result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, scale,
+ groups, pin_layout)
+ results = [result]
+ else:
+ torch_xla._XLAC._xla_all_reduce_inplace(reduce_type, inputs, scale, groups,
+ pin_layout)
+ results = inputs
+ return results[0] if isinstance(inputs, torch.Tensor) else results
+
+
+def _all_gather_using_all_reduce(
+ value: torch.Tensor,
+ dim: int = 0,
+ groups: Optional[List[List[int]]] = None,
+ pin_layout: bool = True) -> Optional[torch.Tensor]:
+ """Performs an all-gather operation using all-reduce along a given dimension.
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ dim (int): The gather dimension.
+ Default: 0
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    pin_layout (bool, optional): Whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it might
+      cause some XLA compilations to fail. Unpin the layout when you see an error
+      message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ A tensor which has, in the ``dim`` dimension, all the values from the
+ participating replicas.
+ """
+ if dim < 0:
+ dim = value.dim() + dim
+ size = value.size(dim)
+ padding = [0] * (2 * value.dim())
+ ordinal = runtime.global_ordinal()
+ if groups is None:
+ left, right = ordinal, runtime.world_size() - 1 - ordinal
+ else:
+ ordinals = dict()
+ for g in groups:
+ for i, x in enumerate(g):
+ ordinals[x] = (i, len(g) - 1 - i)
+ left, right = ordinals[ordinal]
+ idx = value.dim() - 1 - dim
+ padding[2 * idx] = left * size
+ padding[2 * idx + 1] = right * size
+ return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
+
+
+def all_gather(value: torch.Tensor,
+ dim: int = 0,
+ groups: Optional[List[List[int]]] = None,
+ output: Optional[torch.Tensor] = None,
+ pin_layout: bool = True) -> torch.Tensor:
+ """Performs an all-gather operation along a given dimension.
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ dim (int): The gather dimension.
+ Default: 0
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    output (torch.Tensor): Optional output tensor.
+    pin_layout (bool, optional): Whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it might
+      cause some XLA compilations to fail. Unpin the layout when you see an error
+      message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ A tensor which has, in the ``dim`` dimension, all the values from the
+ participating replicas.
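+
+  Example (illustrative sketch; assumes a multi-process/replicated XLA context):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> t = torch.ones(2, 2, device=xm.xla_device())
+    >>> gathered = xm.all_gather(t, dim=0)  # shape: (2 * world_size, 2)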
+ """
+ # _all_gather_using_all_reduce does not support list of tensors as input
+ if pin_layout and output == None and isinstance(value, torch.Tensor):
+ # There is not an easy way to pin the all_gather layout, so use all_reduce
+ # based all_gather for this purpose.
+ return _all_gather_using_all_reduce(
+ value, dim=dim, groups=groups, pin_layout=True)
+
+ if dim < 0:
+ dim = value.dim() + dim
+ if groups:
+ shard_count = len(groups[0])
+ assert all(len(group) == shard_count for group in groups), \
+ "Replica groups must have the same number of replicas/shards."
+ else:
+ # All replicas belong to a single group
+ shard_count = runtime.world_size()
+
+ token, devctx = _get_all_reduce_token()
+
+ if isinstance(value, torch.Tensor):
+ if output != None:
+ # Call the out of place version of the all_gather
+ new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
+ shard_count, groups or [],
+ pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
+ [], pin_layout)
+ return result
+
+ # Now the input should be a list of Tensors.
+ elif isinstance(value, list) and all(
+ isinstance(v, torch.Tensor) for v in value):
+ if pin_layout:
+ raise RuntimeError(
+ "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
+ )
+ if output != None:
+ if not isinstance(output, list) or any(
+ not isinstance(v, torch.Tensor) for v in output):
+ raise TypeError(
+ f"`output` needs to be a list of Tensors, but given {type(output)}."
+ )
+ if len(output) != len(value):
+ raise ValueError("`output` length doesn't match `input` length: "
+ f"{len(output)} vs {len(input)}.")
+ # Call the out of place version of the reduce_scatter
+ new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
+ output, value, token, dim, shard_count, groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
+ shard_count, groups or
+ [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+ return result[:-1]
+ else:
+ raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
+ f"given {type(value)}.")
+
+
+class CoalescingBuckets(object):
+
+ def __init__(
+ self,
+ func: Callable[[
+ Union[torch.Tensor,
+ List[torch.Tensor]], Optional[Union[torch.Tensor,
+ List[torch.Tensor]]]
+ ], Union[torch.Tensor, List[torch.Tensor]]],
+ input_list: Any,
+ output_list: Optional[Any] = None,
+ bucket_cap_mb: int = 160):
+ if not isinstance(input_list, list) or any(
+ not isinstance(v, torch.Tensor) for v in input_list):
+ raise TypeError(
+ f"`input_list` needs to be a list of Tensors, but given {type(input_list)}."
+ )
+ if output_list != None:
+ if not isinstance(output_list, list) or any(
+ not isinstance(v, torch.Tensor) for v in output_list):
+ raise TypeError(
+ f"`output_list` needs to be a list of Tensors, but given {type(output_list)}."
+ )
+ if len(output_list) != len(input_list):
+ raise ValueError(
+ "`output_list` length doesn't match `input_list` length: "
+ f"{len(output_list)} vs {len(input_list)}.")
+ self._func = func
+ self._input_list = input_list
+ self._output_list = output_list
+ self._total = 0
+ self._tensor_bucket = []
+ self._output_bucket = [] if output_list else None
+ self._bucket_cap = bucket_cap_mb * 1024 * 1024
+ self._out_tensors = []
+
+ def flush(self):
+ if len(self._tensor_bucket) == 1:
+      # Use non-coalesced CCOp if it's just one tensor
+ output = self._output_bucket[0] if self._output_bucket else None
+ self._out_tensors.append(self._func(self._tensor_bucket[0], output))
+ elif len(self._tensor_bucket):
+ self._out_tensors.extend(
+ self._func(self._tensor_bucket, self._output_bucket))
+ self._total = 0
+ self._tensor_bucket = []
+ self._output_bucket = [] if self._output_list else None
+
+ def add(self, tensor: torch.Tensor, idx: int):
+ self._total += tensor.numel() * tensor.element_size()
+ self._tensor_bucket.append(tensor)
+ if self._output_list != None:
+ self._output_bucket.append(self._output_list[idx])
+
+ def __call__(self) -> Union[torch.Tensor, List[torch.Tensor]]:
+ for idx, tensor in enumerate(self._input_list):
+ tensor_bytes = tensor.numel() * tensor.element_size()
+
+ # Aim for target bucket_cap_mb: flush new tensor with bucket if bucket content
+ # is small (1/2 cap) but don't combine if combined total is over 2x cap
+ total_new = self._total + tensor_bytes
+ if tensor_bytes > self._bucket_cap and self._total < 0.5 * self._bucket_cap and total_new <= 2 * self._bucket_cap:
+ self.add(tensor, idx)
+ self.flush()
+ else:
+ # Bucketize till the total spills over
+ if total_new > self._bucket_cap:
+ self.flush()
+ self.add(tensor, idx)
+
+ # Flush the last remaining bucket
+ self.flush()
+
+ assert len(self._out_tensors) == len(self._input_list)
+
+ return self._out_tensors
+
+
+def all_gather_bucketized(
+ input_list: List[torch.Tensor],
+ dim: int = 0,
+ groups: Optional[List[List[int]]] = None,
+ output: Optional[torch.Tensor] = None,
+ pin_layout: bool = False,
+ bucket_cap_mb=160) -> Union[torch.Tensor, List[torch.Tensor]]:
+ """Performs an all-gather operation along a given dimension, with bucketization.
+
+ Args:
+ See all_gather for the args: dim, groups, output, pin_layout
+ input_list: List of input tensors
+ bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather.
+
+ Returns:
+ A list of tensors each of which has, in the ``dim`` dimension, all the values from the
+ participating replicas.
+ """
+ # sanity checks
+ if pin_layout:
+ raise RuntimeError(
+ "For xm.all_gather_bucketized, pin_layout=True is not yet supported.")
+
+ def _all_gather_coalesced(_input_list, _output_list=None):
+ return all_gather(
+ value=_input_list,
+ dim=dim,
+ groups=groups,
+ output=_output_list,
+ pin_layout=pin_layout)
+
+ buckets = CoalescingBuckets(
+ _all_gather_coalesced, input_list, output, bucket_cap_mb=bucket_cap_mb)
+ return buckets()
+
+
+def all_to_all(value: torch.Tensor,
+ split_dimension: int,
+ concat_dimension: int,
+ split_count: int,
+ groups: Optional[List[List[int]]] = None,
+ pin_layout: bool = True) -> torch.Tensor:
+ """Performs an XLA `AllToAll()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#alltoall
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ split_dimension (int): The dimension upon which the split should happen.
+ concat_dimension (int): The dimension upon which the concat should happen.
+ split_count (int): The split count.
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_to_all()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    pin_layout (bool, optional): Whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it might
+      cause some XLA compilations to fail. Unpin the layout when you see an error
+      message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ The result `torch.Tensor` of the `all_to_all()` operation.
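+
+  Example (illustrative sketch; assumes `split_count` equals the replica count):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> import torch_xla.runtime as xr
+    >>> world_size = xr.world_size()
+    >>> t = torch.ones(world_size, 4, device=xm.xla_device())
+    >>> out = xm.all_to_all(t, split_dimension=0, concat_dimension=1,
+    ...                     split_count=world_size)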
+ """
+ token, devctx = _get_all_reduce_token()
+ result = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension,
+ concat_dimension, split_count,
+ groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+ return result[0]
+
+
+def collective_permute(value: torch.Tensor,
+ pairs: List[List[int]]) -> torch.Tensor:
+ """Performs a XLA `CollectivePermute()` operation on the input tensor.
+
+ WARNING: This function is not very reliable, may produce wrong results under
+ certain inputs. Use it at your own risk.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ pairs (list): A list of (source_replica_id, target_replica_id) pairs,
+ representing the sender and receiver for the `collective_permute()`
+ operation. Example: `[[0, 1], [1, 2], [2, 0]]` defines three pairs. The
+ tensor will be sent from replica 0 to replica 1, replica 1 to replica 2,
+ and replica 2 to replica 0.
+
+ Returns:
+ The result `torch.Tensor` of the `collective_permute()` operation.
+ """
+ token, devctx = _get_all_reduce_token()
+ result = torch_xla._XLAC._xla_collective_permute(value, token, pairs)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+ return result[0]
+
+
+def collective_broadcast(tensors: List[torch.Tensor],
+ root_ordinal: int = 0,
+ groups: Optional[List[int]] = None,
+ pin_layout: bool = True) -> None:
+ """Broadcast values of `tensors` from root replica to other replicas in-place.
+
+ Args:
+ tensors (list): List of `torch.Tensor`s to broadcast.
+ root_ordinal (int): Ordinal of replica with values to broadcast.
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    pin_layout (bool, optional): Whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it might
+      cause some XLA compilations to fail. Unpin the layout when you see an error
+      message like "HloModule has a mix of layout constrained".
+ """
+ with torch.no_grad():
+ # We must produce the exact same graph in each replica to prevent hanging,
+ # so each replica must have the same multiply op with the same parameters.
+ for tensor in tensors:
+ scale = torch.tensor(
+ 1 if runtime.global_ordinal() == root_ordinal else 0,
+ dtype=tensor.dtype)
+ # Transfer scale tensor as device data instead of constant 1 or 0.
+ xscale = send_cpu_data_to_device(scale, tensor.device)
+ tensor.mul_(xscale[0])
+
+ all_reduce(REDUCE_SUM, tensors, groups=groups, pin_layout=pin_layout)
+
+
+def send(value: torch.Tensor, channel_id: int) -> torch.Tensor:
+ """Performs a XLA `Send()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#send
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ channel_id (int64): opaque id identifying the destination of the send op.
+ """
+ token, devctx = _get_all_reduce_token()
+ # The input will be returned as result.
+ input_as_result, new_token = torch_xla._XLAC._xla_send(
+ value, token, channel_id)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return input_as_result
+
+
+def recv(output: torch.Tensor, channel_id: int) -> torch.Tensor:
+ """Performs a XLA `Recv()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#recv
+
+ Args:
+ output (torch.Tensor): The output tensor.
+ channel_id (int64): opaque id identifying the source of the recv op.
+ """
+ token, devctx = _get_all_reduce_token()
+ result, new_token = torch_xla._XLAC._xla_recv(output, token, channel_id)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return result
+
+
+def reduce_scatter(reduce_type: str,
+ input: Union[torch.Tensor, List[torch.Tensor]],
+ scale: float,
+ scatter_dim: int,
+ shard_count: int,
+ groups: Optional[List[List[int]]] = None,
+ output: Optional[Union[torch.Tensor,
+ List[torch.Tensor]]] = None,
+ pin_layout: bool = True) -> torch.Tensor:
+ """Performs a XLA `ReduceScatter()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
+
+ Args:
+ reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
+ ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
+ ``xm.REDUCE_MAX``.
+    input (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
+      it will also be the output.
+    scale (float): A default scaling value to be applied after the reduce.
+    scatter_dim (int): Dimension number to which the scatter operation is applied.
+    shard_count (int): The number of shards to split the scatter_dim into.
+    groups (list): A list of lists, representing the replica groups for
+      the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    output: Optional output tensor if `input` is a torch.Tensor, or a list of
+      torch.Tensor if `input` is a list of torch.Tensor.
+    pin_layout (bool, optional): Whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it might
+      cause some XLA compilations to fail. Unpin the layout when you see an error
+      message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ A `torch.Tensor` with all the values reduced across replicas. Each process
+ gets a shard split along the `scatter_dim`. All other dimensions are
+ the same as the input.
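+
+  Example (illustrative sketch; assumes `shard_count` equals the replica count):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> import torch_xla.runtime as xr
+    >>> world_size = xr.world_size()
+    >>> t = torch.ones(world_size * 2, 4, device=xm.xla_device())
+    >>> shard = xm.reduce_scatter(xm.REDUCE_SUM, t, scale=1.0 / world_size,
+    ...                           scatter_dim=0, shard_count=world_size)
+    >>> # each replica now holds a (2, 4) shard of the reduced tensor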
+ """
+ token, devctx = _get_all_reduce_token()
+
+ if isinstance(input, torch.Tensor):
+ if output != None:
+ # Call the out of place version of the reduce_scatter
+ new_token = torch_xla._XLAC._xla_reduce_scatter_out(
+ reduce_type, output, input, token, scale, scatter_dim, shard_count,
+ groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
+ scale, scatter_dim,
+ shard_count, groups or [],
+ pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+ return result[0]
+
+ # Now the input should be a list of Tensors.
+ elif isinstance(input, list) and all(
+ isinstance(v, torch.Tensor) for v in input):
+ if output != None:
+ if not isinstance(output, list) or any(
+ not isinstance(v, torch.Tensor) for v in output):
+ raise TypeError(
+ f"`output` needs to be a list of Tensors, but given {type(output)}."
+ )
+ if len(output) != len(input):
+ raise ValueError("`output` length doesn't match `input` length: "
+ f"{len(output)} vs {len(input)}.")
+ # Call the out of place version of the reduce_scatter
+ new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
+ reduce_type, output, input, token, scale, scatter_dim, shard_count,
+ groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
+ reduce_type, input, token, scale, scatter_dim, shard_count, groups or
+ [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+ return result[:-1]
+ else:
+ raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
+ f"given {type(input)}.")
+
+
+def reduce_scatter_bucketized(reduce_type: str,
+ input_list: Union[torch.Tensor,
+ List[torch.Tensor]],
+ scale: float,
+ scatter_dim: int,
+ shard_count: int,
+ groups: Optional[List[List[int]]] = None,
+ output: Optional[Union[
+ torch.Tensor, List[torch.Tensor]]] = None,
+ pin_layout: bool = False,
+ bucket_cap_mb: int = 160) -> CoalescingBuckets:
+ """Performs a XLA `ReduceScatter()` operation on a list of tensors (bucketized).
+
+ See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
+
+ Args:
+ see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout
+ input_list: List of input tensors
+ output: Optional list of output torch.Tensor
+ bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing reduce-scatter.
+
+ Returns:
+ A list of `torch.Tensors` with all the values reduced across replicas. Each process
+ gets a shard split along the `scatter_dim`. All other dimensions are
+ the same as the input.
+ """
+
+ def _reduce_scatter_coalesced(
+ _input_list: Union[torch.Tensor, List[torch.Tensor]],
+ _output_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ return reduce_scatter(
+ reduce_type=reduce_type,
+ input=_input_list,
+ scale=scale,
+ scatter_dim=scatter_dim,
+ shard_count=shard_count,
+ groups=groups,
+ output=_output_list,
+ pin_layout=pin_layout)
+
+ buckets = CoalescingBuckets(
+ _reduce_scatter_coalesced,
+ input_list,
+ output,
+ bucket_cap_mb=bucket_cap_mb)
+ return buckets()
+
+
+def add_step_closure(closure: Callable[..., Any],
+ args: Tuple[Any] = (),
+ run_async: bool = False):
+ """Adds a closure to the list of the ones to be run at the end of the step.
+
+  Many times during model training there is the need to print/report (print to
+  console, post to tensorboard, etc...) information which requires the content of
+  intermediary tensors to be inspected.
+  Inspecting the content of different tensors at different points of the model code
+  requires many executions and typically causes performance issues.
+  Adding a step closure ensures that it will be run after the barrier, when all
+  the live tensors (including the ones captured by the closure arguments) have
+  already been materialized to device data.
+  So using `add_step_closure()` ensures that a single execution is performed,
+  even when multiple closures are queued, each requiring multiple tensors to be
+  inspected.
+ Step closures will be run sequentially in the order they have been queued.
+ Note that even though using this API the execution will be optimized, it is
+ advised to throttle the printing/reporting events once every N steps.
+
+ Args:
+ closure (callable): The function to be called.
+ args (tuple): The arguments to be passed to the closure.
+ run_async: If True, run the closure asynchronously.
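+
+  Example (illustrative sketch; `loss` and `_report` are hypothetical, with
+  `loss` assumed to be a live XLA tensor):
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> def _report(loss):
+    ...   print(f'loss={loss.item()}')
+    >>> xm.add_step_closure(_report, args=(loss,))
+    >>> xm.mark_step()  # closure runs after the step materializes `loss`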
+ """
+ devctx = _get_device_context()
+ closures_type = 'async_step_closures' if run_async else 'step_closures'
+ step_closures = getattr(devctx, closures_type, None)
+ if step_closures is None:
+ step_closures = []
+ setattr(devctx, closures_type, step_closures)
+ step_closures.append(lambda a=args: closure(*a))
+
+
+def _run_step_closures() -> DeviceContext:
+ devctx = _get_device_context()
+ async_step_closures = getattr(devctx, 'async_step_closures', None)
+ if async_step_closures is not None:
+ devctx.async_step_closures = []
+ async_closure_handler = getattr(devctx, 'async_closure_handler', None)
+ if async_closure_handler is None:
+ async_closure_handler = xc.AsyncClosureHandler()
+ devctx.async_closure_handler = async_closure_handler
+ async_closure_handler.run_all(async_step_closures)
+
+ step_closures = getattr(devctx, 'step_closures', None)
+ if step_closures is not None:
+ devctx.step_closures = []
+ for closure in step_closures:
+ closure()
+ return devctx
+
+
+def mark_step(wait: bool = False, reset_scope: bool = True):
+ if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
+ print(
+ 'torch_xla.core.xla_model::mark_step\n',
+ end='',
+ file=sys.stderr,
+ flush=True)
+ torch_xla._XLAC._xla_step_marker(
+ torch_xla._XLAC._xla_get_default_device(), [],
+ wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait),
+ reset_scope=reset_scope)
+ # Only emit metrics from the first local device index, to avoid emitting the
+ # same values from different threads.
+ if is_master_ordinal():
+ ms.save_metrics()
+ devctx = _run_step_closures()
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
+
+
+# TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
+# dumped as outputs. Needs further investigation.
+def get_stablehlo(tensors: Optional[List[torch.Tensor]] = None) -> str:
+ """Get StableHLO for the computation graph in string format.
+
+  If `tensors` is not empty, the graph with `tensors` as outputs will be dumped.
+  If `tensors` is empty, the whole computation graph will be dumped.
+
+  For an inference graph, it is recommended to pass the model outputs to `tensors`.
+  For a training graph, it is not straightforward to identify the "outputs". Using empty `tensors` is recommended.
+
+ To enable source line info in StableHLO, please set env var XLA_HLO_DEBUG=1.
+
+ Args:
+ tensors (list[torch.Tensor], optional): Tensors that represent the output/root of the StableHLO graph.
+
+ Returns:
+ StableHLO Module in string format.
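+
+  Example (illustrative sketch; traces a small computation and dumps its StableHLO):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> a = torch.randn(2, 2, device=xm.xla_device())
+    >>> b = a @ a
+    >>> print(xm.get_stablehlo([b]))  # StableHLO with `b` as the graph output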
+ """
+ if tensors is None:
+ tensors = []
+ return torch_xla._XLAC._get_stablehlo(
+ tensors, torch_xla._XLAC._xla_get_default_device(), [],
+ False).decode('utf-8')
+
+
+# TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
+# dumped as outputs. Needs further investigation.
+def get_stablehlo_bytecode(tensors: Optional[torch.Tensor] = None) -> bytes:
+ """Get StableHLO for the computation graph in bytecode format.
+
+  If `tensors` is not empty, the graph with `tensors` as outputs will be dumped.
+  If `tensors` is empty, the whole computation graph will be dumped.
+
+  For an inference graph, it is recommended to pass the model outputs to `tensors`.
+  For a training graph, it is not straightforward to identify the "outputs". Using empty `tensors` is recommended.
+
+ Args:
+ tensors (list[torch.Tensor], optional): Tensors that represent the output/root of the StableHLO graph.
+
+ Returns:
+ StableHLO Module in bytecode format.
+ """
+ if tensors is None:
+ tensors = []
+ return torch_xla._XLAC._get_stablehlo(
+ tensors, torch_xla._XLAC._xla_get_default_device(), [], True)
+
+
+def wait_device_ops(devices: List[str] = []):
+ """Waits for all the async operations on the given devices to complete.
+
+ Args:
+ devices (string..., optional): The devices whose async ops need to be waited
+ for. If empty, all the local devices will be waited for.
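+
+  Example (illustrative sketch):
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> xm.mark_step()        # dispatch any pending computation
+    >>> xm.wait_device_ops()  # block until all local devices are idle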
+ """
+ torch_xla._XLAC._xla_wait_device_ops(devices=devices)
+
+
+def all_reduce_bucketized_gradients(gradients: List[torch.Tensor],
+ scale: float,
+ groups: Optional[List[List[int]]],
+ pin_layout: bool,
+ bucket_cap_mb: int = 0):
+ total = 0
+ tensor_bucket = []
+ bucket_cap = bucket_cap_mb * 1024 * 1024
+
+ for grad in gradients:
+ grad_bytes = grad.numel() * grad.element_size()
+
+ # Bucketize till the total spills over
+ total += grad_bytes
+ if total > bucket_cap and len(tensor_bucket) > 0:
+ all_reduce(
+ REDUCE_SUM,
+ tensor_bucket,
+ scale=scale,
+ groups=groups,
+ pin_layout=pin_layout)
+ total = grad_bytes
+ tensor_bucket = []
+ tensor_bucket.append(grad)
+
+ # Flush the last remaining bucket
+ if len(tensor_bucket):
+ all_reduce(
+ REDUCE_SUM,
+ tensor_bucket,
+ scale=scale,
+ groups=groups,
+ pin_layout=pin_layout)
+
+
+def reduce_gradients(optimizer: optim.Optimizer,
+ groups: Optional[List[List[int]]] = None,
+ pin_layout: bool = True):
+ """Reduces all the gradients handled by an optimizer.
+
+ Args:
+ optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance
+ containing the gradients to be reduced.
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    pin_layout (bool, optional): Whether to pin the layout when reducing gradients.
+ See `xm.all_reduce` for details.
+ """
+ count = runtime.world_size()
+ if count > 1:
+ gradients = _fetch_gradients(optimizer)
+ bucket_cap_mb = int(os.getenv('ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB', 0))
+ # Reverse the gradients list so that we start allreduce from the last layer
+ # onwards. This allows allreduce to trigger as soon as the bucket fills up and
+ # overlap with backward pass.
+ if bucket_cap_mb > 0:
+ gradients = reversed(gradients)
+ all_reduce_bucketized_gradients(
+ gradients,
+ scale=1.0 / count,
+ groups=groups,
+ pin_layout=pin_layout,
+ bucket_cap_mb=bucket_cap_mb)
+ else:
+ all_reduce(
+ REDUCE_SUM,
+ gradients,
+ scale=1.0 / count,
+ groups=groups,
+ pin_layout=pin_layout)
+
+
+def optimizer_step(optimizer: optim.Optimizer,
+ barrier: bool = False,
+ optimizer_args: Dict = {},
+ groups: Optional[List[List[int]]] = None,
+ pin_layout: bool = True):
+ """Run the provided optimizer step and sync gradidents across all devices.
+
+ Args:
+ optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance whose
+ `step()` function needs to be called. The `step()` function will be called
+ with the `optimizer_args` named arguments.
+ barrier (bool, optional): Whether the XLA tensor barrier should be issued in
+ this API. If using the PyTorch XLA `ParallelLoader` or `DataParallel`
+ support, this is not necessary as the barrier will be issued by the XLA
+ data loader iterator `next()` call.
+ Default: False
+ optimizer_args (dict, optional): Named arguments dictionary for the
+ `optimizer.step()` call.
+    groups (list, optional): A list of lists, representing the replica groups for
+      the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+      the `[4, 5, 6, 7]` replicas. If `None`, there will be only one group with
+      all the replicas in it.
+    pin_layout (bool, optional): Whether to pin the layout when reducing gradients.
+ See `xm.all_reduce` for details.
+
+ Returns:
+ The same value returned by the `optimizer.step()` call.
+
+ Example:
+
+ >>> import torch_xla.core.xla_model as xm
+ >>> xm.optimizer_step(self.optimizer)
+ """
+ reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
+ loss = optimizer.step(**optimizer_args)
+ if barrier:
+ mark_step()
+ return loss
+
+
+def save(data: Any,
+ file_or_path: Union[str, TextIO],
+ master_only: bool = True,
+ global_master: bool = False):
+ """Saves the input data into a file.
+
+ The saved data is transferred to PyTorch CPU device before being saved, so a
+ following `torch.load()` will load CPU data.
+ Care must be taken when working with views. Instead of saving views it's
+ recommended that you recreate them after the tensors have been loaded and
+ moved to their destination device(s).
+
+ Args:
+ data: The input data to be saved. Any nested combination of Python objects
+ (list, tuples, sets, dicts, ...).
+ file_or_path: The destination for the data saving operation. Either a file
+ path or a Python file object. If `master_only` is ``False`` the path or
+ file objects must point to different destinations as otherwise all the
+ writes from the same host will override each other.
+ master_only (bool, optional): Whether only the master device should save the
+ data. If False, the `file_or_path` argument should be a different file or
+      path for each of the ordinals taking part in the replication, otherwise
+ all the replicas on the same host will be writing to the same location.
+ Default: True
+ global_master (bool, optional): When ``master_only`` is ``True`` this flag
+ controls whether every host's master (if ``global_master`` is ``False``)
+ saves the content, or only the global master (ordinal 0).
+ Default: False
+
+ Example:
+
+ >>> import torch_xla.core.xla_model as xm
+ >>> xm.wait_device_ops() # wait for all pending operations to finish.
+ >>> xm.save(obj_to_save, path_to_save)
+ >>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
+ """
+ should_write_data = not master_only or is_master_ordinal(
+ local=not global_master)
+
+ cpu_data = _maybe_convert_to_cpu(data, convert=should_write_data)
+ if should_write_data:
+ torch.save(cpu_data, file_or_path)
+
+
+def _maybe_convert_to_cpu(data: Any, convert: bool = True) -> ToXlaTensorArena:
+
+ def convert_fn(tensors):
+ torch_xla._XLAC._xla_sync_multi(
+ tensors, devices=[], wait=True, sync_xla_data=True)
+ if not convert:
+ return tensors
+ return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
+
+ def select_fn(v):
+ return type(v) == torch.Tensor and is_xla_tensor(v)
+
+ return ToXlaTensorArena(convert_fn, select_fn).transform(data)
+
+
+def send_cpu_data_to_device(
+ datas: Any,
+ device: Union[str, torch.device],
+ input_sharding: Optional[ShardingSpec] = None) -> ToXlaTensorArena:
+
+ def convert_fn(tensors):
+ devices = [str(device)] * len(tensors)
+ shardings = None
+ if input_sharding:
+ shardings = [input_sharding.xla_spec(t) for t in tensors]
+ xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
+ shardings)
+ return xtensors
+
+ def select_fn(v):
+ return type(v) == torch.Tensor and v.device.type == 'cpu'
+
+ if type(datas) is torch.Tensor:
+ datas = [datas]
+ return ToXlaTensorArena(convert_fn, select_fn).transform(datas)
+
+
+def xla_rendezvous(payload: bytes = b'',
+ ordinals: Optional[List[int]] = None,
+ tag: Optional[str] = None) -> List[bytes]:
+ """Share `payload` with all replicas in `ordinals`.
+
+ `tag` is ignored except for logging.
+
+ Uses XLA collective communication to communicate between replicas, so this
+ will sync the graph (`xm.mark_step`).
+
+ Args:
+ tag: Name of this rendezvous operation.
+ payload: Payload to share with other replicas.
+ ordinals: List of replicas participating in rendezvous.
+ Returns:
+ List of bytes from other replicas.
+ """
+ if ordinals and len(ordinals) != runtime.global_device_count():
+ raise ValueError('Only global rendezvous is supported')
+
+ if not isinstance(payload, bytes):
+ raise TypeError('`payload` must be bytes, not {}'.format(type(payload)))
+
+ # Finish all execution of previous graphs to avoid recompilation
+ mark_step()
+
+ device = xla_device()
+
+ data = torch.tensor(list(payload), dtype=torch.uint8)
+ size = torch.tensor([data.shape[0]], dtype=torch.int, device=device)
+
+ if tag:
+ logging.info(f"Joining rendezvous '{tag}'...")
+
+ sizes = all_gather(size)
+
+ max_size = torch.max(sizes)
+ mark_step()
+
+ # If all payloads are empty, return immediately to avoid more CPU transfers
+ if max_size.item() < 1:
+ return [b'' for _ in range(sizes.size()[0])]
+
+ padded_data = torch.nn.functional.pad(data, (
+ 0,
+ max_size.item() - size.item(),
+ )).to(xla_device())
+ raw_data = all_gather(padded_data)
+ data_list = torch.split(raw_data, max_size)
+
+ payloads = [d[:sz] for d, sz in zip(data_list, sizes.cpu())]
+ mark_step()
+
+ return [bytes(p.cpu().tolist()) for p in payloads]
+
+
+def rendezvous(tag: str,
+ payload: bytes = b'',
+ replicas: List[int] = []) -> List[bytes]:
+ """Waits for all the mesh clients to reach the named rendezvous.
+
+ Note: PJRT does not support the XRT mesh server, so this is effectively an
+ alias to `xla_rendezvous`.
+
+ Args:
+ tag (string): The name of the rendezvous to join.
+ payload (bytes, optional): The payload to be sent to the rendezvous.
+    replicas (list, int): The replica ordinals taking part in the rendezvous.
+ Empty means all replicas in the mesh.
+ Default: []
+
+ Returns:
+ The payloads exchanged by all the other cores, with the payload of core
+ ordinal `i` at position `i` in the returned tuple.
+
+ Example:
+
+ >>> import torch_xla.core.xla_model as xm
+ >>> xm.rendezvous('example')
+ """
+ return xla_rendezvous(payload, replicas or None, tag=tag)
+
+
+def do_on_ordinals(
+ target: Callable[..., Any],
+ data: Union[Tuple, Any] = (),
+ ordinals: Union[List[int], Set[int], int] = (0,)
+) -> Optional[Any]:
+ """Runs a function only on a given set of ordinals.
+
+ Args:
+ target (callable): The function to be run on `ordinals`.
+ data: Any input data for the `target` function which contains tensors. All
+ the XLA tensors used by the `target` function must be passed in this
+ argument. Every other data used by the function can be captured by the
+ Python interpreter as usual.
+ Default: ()
+ ordinals (list, int): The list/set of ordinals where the `target` function
+ should run.
+ Default: (0,)
+
+ Returns:
+ In the ordinals that ran the `target` function, the function return value,
+ otherwise `None`.
+ """
+ running = runtime.global_ordinal() in ordinals
+ cpu_data = _maybe_convert_to_cpu(data, convert=running)
+ if running:
+ result = target(*cpu_data)
+ else:
+ result = None
+ rendezvous('torch_xla.core.xla_model.do_on_ordinals')
+ return result
+
+
+def mesh_reduce(tag: str, data,
+ reduce_fn: Callable[..., Any]) -> Union[Any, ToXlaTensorArena]:
+ """Performs an out-of-graph client mesh reduction.
+
+ Args:
+ tag (string): The name of the rendezvous to join.
+ data: The data to be reduced. The `reduce_fn` callable will receive a list
+ with the copies of the same data coming from all the mesh client processes
+ (one per core).
+ reduce_fn (callable): A function which receives a list of `data`-like
+ objects and returns the reduced result.
+
+ Returns:
+ The reduced value.
+
+ Example:
+
+ >>> import torch_xla.core.xla_model as xm
+ >>> import numpy as np
+ >>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
+ """
+ cpu_data = _maybe_convert_to_cpu(data)
+ bio = io.BytesIO()
+ torch.save(cpu_data, bio)
+ xdata = rendezvous(tag, bio.getvalue())
+ xldata = []
+ for xd in xdata:
+ xbio = io.BytesIO(xd)
+ xldata.append(torch.load(xbio))
+ return reduce_fn(xldata) if xldata else cpu_data
+
+
+def set_rng_state(seed: int, device: Optional[str] = None):
+ """Sets the random number generator state.
+
+ Args:
+ seed (integer): The state to be set.
+ device (string, optional): The device where the RNG state needs to be set.
+      If missing, the seed of the default device will be set.
+ """
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ torch_xla._XLAC._xla_set_rng_seed(seed, str(device) if device else '')
+
+
+def get_rng_state(device: Optional[str] = None) -> int:
+ """Gets the current running random number generator state.
+
+ Args:
+ device (string, optional): The device whose RNG state needs to be retrieved.
+      If missing, the RNG state of the default device will be retrieved.
+
+ Returns:
+ The RNG state, as integer.
+ """
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ return torch_xla._XLAC._xla_get_rng_seed(str(device) if device else '')
+
+
+@contextlib.contextmanager
+def fork_rng(device: Optional[str] = None, enabled: bool = True):
+ """
+ Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in.
+ Args:
+    device (string, optional): The device where the RNG state needs to be set. If missing, the seed of the default device will be set.
+ enabled (bool): if ``False``, the RNG is not forked. This is a convenience argument for easily disabling the context manager without having to delete it and unindent your Python code under it.
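+
+  Example (illustrative sketch; RNG draws inside the block do not affect the outer state):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> with xm.fork_rng():
+    ...   torch.randn(2, 2, device=xm.xla_device())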
+ """
+ if not enabled:
+ yield
+ return
+
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ xla_rng_state = get_rng_state(device=device)
+
+ try:
+ yield
+ finally:
+ set_rng_state(xla_rng_state, device=device)
+
+
+class MemoryInfo(TypedDict):
+ bytes_used: str
+ bytes_limit: int
+
+
+def get_memory_info(device: Optional[torch.device] = None) -> MemoryInfo:
+ """Retrieves the device memory usage.
+
+ Args:
+    device (torch.device, optional): The device whose memory information is requested.
+      If not passed, the default device will be used.
+
+ Returns:
+ MemoryInfo dict with memory usage for the given device.
+
+ Example:
+
+ >>> xm.get_memory_info()
+ {'bytes_used': 290816, 'bytes_limit': 34088157184}
+ """
+ if device == None:
+ device = xla_device()
+ return torch_xla._XLAC._xla_memory_info(str(device))
+
+
+def optimization_barrier_(tensors: List[torch.Tensor]):
+ """Blocks xla compiler from moving computations across this barrier. The common
+ use case would be blocking xla common-subexpression elimination pass from undoing
+ the gradient checkpointing.
+
+ Args:
+ tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
+ """
+ torch_xla._XLAC._xla_optimization_barrier_(tensors)
+
+
+def broadcast_master_param(model: torch.nn.Module) -> None:
+ """
+  Broadcast the model parameters from the master process to the other processes.
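+
+  Example (illustrative sketch; typically called once after model creation in each process):
+
+    >>> import torch
+    >>> import torch_xla.core.xla_model as xm
+    >>> model = torch.nn.Linear(4, 4).to(xm.xla_device())
+    >>> xm.broadcast_master_param(model)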
+ """
+ parameters_and_buffers = list(
+ itertools.chain(model.parameters(), model.buffers()))
+ collective_broadcast(parameters_and_buffers)
+ mark_step()
+
+import torch_xla
+
+
+def counter_names():
+ """Retrieves all the currently active counter names."""
+ return torch_xla._XLAC._xla_counter_names()
+
+
+def counter_value(name):
+ """Returns the value of an active counter.
+
+ Args:
+ name (string): The name of the counter whose value needs to be retrieved.
+
+ Returns:
+ The counter value as integer.
+ """
+ return torch_xla._XLAC._xla_counter_value(name)
+
+
+def clear_counters():
+ """Clear the value of all counters.
+ """
+ return torch_xla._XLAC._clear_xla_counters()
+
+
+def metric_names():
+ """Retrieves all the currently active metric names."""
+ return torch_xla._XLAC._xla_metric_names()
+
+
+def metric_data(name):
+ """Returns the data of an active metric.
+
+ Args:
+ name (string): The name of the metric whose data needs to be retrieved.
+
+ Returns:
+ The metric data, which is a tuple of (TOTAL_SAMPLES, ACCUMULATOR, SAMPLES).
+ The `TOTAL_SAMPLES` is the total number of samples which have been posted to
+ the metric. A metric retains only a given number of samples (in a circular
+ buffer).
+ The `ACCUMULATOR` is the sum of the samples over `TOTAL_SAMPLES`.
+ The `SAMPLES` is a list of (TIME, VALUE) tuples.
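+
+  Example (illustrative sketch; assumes `import torch_xla.debug.metrics as met`
+  and that at least one graph has been compiled so `CompileTime` is populated):
+
+    >>> total_samples, accumulator, samples = met.metric_data('CompileTime')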
+ """
+ return torch_xla._XLAC._xla_metric_data(name)
+
+
+def clear_metrics():
+ """Clear the value of all metrics.
+ """
+ return torch_xla._XLAC._clear_xla_metrics()
+
+
+def clear_all():
+ """Clear the value of all metrics and all counters.
+ """
+ clear_metrics()
+ clear_counters()
+
+
+[docs]def metrics_report():
+ """Retrieves a string containing the full metrics and counters report."""
+ return torch_xla._XLAC._xla_metrics_report()
+
+
+[docs]def short_metrics_report(counter_names: list = None, metric_names: list = None):
+ """Retrieves a string containing the full metrics and counters report.
+
+ Args:
+ counter_names (list): The list of counter names whose data needs to be printed.
+ metric_names (list): The list of metric names whose data needs to be printed.
+ """
+ if not counter_names:
+ counter_names = ['CachedCompile', 'MarkStep', 'DynamoSyncInputExecuteTime']
+ if not metric_names:
+ metric_names = [
+ 'CompileTime', 'ExecuteTime', 'ExecuteReplicatedTime',
+ 'TransferToDeviceTime', 'TransferFromDeviceTime'
+ ]
+ return torch_xla._XLAC._short_xla_metrics_report(counter_names, metric_names)
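+
+
+ def _example_short_metrics_report():
+   # Hypothetical usage sketch (not part of the torch_xla API): restricts the
+   # short report to one counter and one metric of interest.
+   return short_metrics_report(
+       counter_names=['CachedCompile'], metric_names=['CompileTime'])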
+
+
+def executed_fallback_ops():
+ """Retrieves a list of operations that were run in fallback mode."""
+ return torch_xla._XLAC._get_executed_fallback_ops()
+
+import itertools
+import queue
+import threading
+import torch
+import torch_xla
+import torch_xla.debug.profiler as xp
+import torch_xla.utils.keyd_queue as kq
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+
+
+class PerDeviceQueue(object):
+
+ def __init__(self, device, loader_prefetch_size, device_prefetch_size):
+ self.device = device
+ self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
+ self.queue = kq.Queue(maxsize=device_prefetch_size)
+ self.close_queue_count = itertools.count()
+
+
+class PerDeviceLoader(object):
+
+ def __init__(self, loader, device):
+ self._loader = loader
+ self._device = device
+ self._mark_step_batch_count = loader.batches_per_execution - 1
+ self._batches_yielded = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.next()
+
+ def __len__(self):
+ return self._loader.per_device_samples()
+
+ def next(self):
+ if xp.get_tracer_marked_step():
+ xp.set_tracer_marked_step(False)
+ self._batches_yielded += 1
+ else:
+ if self._mark_step_batch_count <= self._batches_yielded:
+ self._batches_yielded = 0
+ xm.mark_step()
+ else:
+ self._batches_yielded += 1
+
+ item = self._loader.next_item(self._device)
+ if item is None:
+ if not self._loader._exception_queue.empty():
+ raise self._loader._exception_queue.get()
+ xm.mark_step()
+ raise StopIteration
+ return item
+
+
+class ParallelLoader(object):
+ """Wraps an existing PyTorch DataLoader with background data upload.
+
+ Args:
+ cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+ wrapped.
+ devices (`torch.device`...): The list of devices where the data has to be
+ sent. The i-th sample returned by the `loader` will be sent to `devices[i
+ % len(devices)]`.
+ batchdim (int, optional): The dimension which is holding the batch size.
+ Default: 0
+ loader_prefetch_size (int, optional): The max capacity of the queue used by
+ the thread which is reading samples from the `loader`, to be processed by
+ the worker threads which upload data to the devices.
+ Default: 16
+ device_prefetch_size (int, optional): The max size of the per-device queues,
+ where the worker threads deposit tensors which have already been sent to
+ devices.
+ Default: 8
+ host_to_device_transfer_threads (int, optional): The number of threads that
+ work in parallel to transfer data from loader queue to device queue.
+ Default: 1
+ input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
+ spec to apply to compatible input tensors after loading.
+ Default: None
+ """
+
+ def __init__(self,
+ cpu_loader,
+ devices,
+ batchdim=0,
+ batches_per_execution=1,
+ loader_prefetch_size=16,
+ device_prefetch_size=8,
+ host_to_device_transfer_threads=1,
+ input_sharding=None):
+ self._cpu_loader = cpu_loader
+ self._devices = [torch.device(x) for x in devices]
+ self._batchdim = batchdim
+ self._batches_per_execution = batches_per_execution
+ self._done = False
+ self._queues = dict()
+ self._exception_queue = queue.Queue()
+ self._input_sharding = input_sharding
+ for device in self._devices:
+ self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
+ device_prefetch_size)
+ thread = threading.Thread(target=self._loader_worker)
+ thread.daemon = True
+ thread.start()
+ for dqueue in self._queues.values():
+ for i in range(host_to_device_transfer_threads):
+ thread = threading.Thread(
+ target=self._worker,
+ args=(
+ dqueue,
+ host_to_device_transfer_threads,
+ ))
+ thread.daemon = True
+ thread.start()
+
+ def per_device_loader(self, device):
+ """Retrieves the loader iterator object for the given device.
+
+ Args:
+ device (`torch.device`): The device whose loader is being requested.
+
+ Returns:
+ The loader iterator object for the `device`. This is not a
+ `torch.utils.data.DataLoader` interface, but a Python iterator which
+ returns the same tensor data structure as returned by the wrapped
+ `torch.utils.data.DataLoader`, but residing on XLA devices.
+ """
+ return PerDeviceLoader(self, torch.device(device))
+
+ def per_device_samples(self):
+ return len(self._cpu_loader) // len(self._devices)
+
+ def next_item(self, device):
+ dqueue = self._queues[device]
+ return dqueue.queue.get()
+
+ def close(self):
+ self._done = True
+ for dqueue in self._queues.values():
+ dqueue.queue.close()
+ dqueue.cpu_loader_queue.close()
+
+ @property
+ def batches_per_execution(self):
+ return self._batches_per_execution
+
+ def _loader_worker(self):
+ queues = list(self._queues.values())
+ data_iter = enumerate(self._cpu_loader)
+ batch = []
+ while not self._done:
+ try:
+ _, data = next(data_iter)
+ except StopIteration:
+ break
+ batch.append(data)
+ if len(batch) == len(self._devices):
+ for queue_no, device_batch in enumerate(batch):
+ queues[queue_no].cpu_loader_queue.put(device_batch)
+ batch = []
+ for dqueue in queues:
+ dqueue.cpu_loader_queue.close_write()
+
+ def _get_batch(self, dqueue):
+ batch = []
+ while len(batch) < dqueue.queue.max_size():
+ item = dqueue.cpu_loader_queue.get()
+ if item is None:
+ break
+ batch.append(item)
+ return batch
+
+ def send_cpu_data_to_device(self, batches, device):
+ """Move batch to device.
+ Args:
+ batch -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batch
+ present in the cpu memory
+ device: TPU device where the batch should be moved
+
+ Returns:
+ result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the
+ input batch is a dict. Otherwise, returns a list of torch.Tensor.
+ """
+ result = None
+ if isinstance(self._input_sharding, dict):
+ if not isinstance(batches[0], dict):
+ raise ValueError(
+ "input batch should be a dict when input sharding is a dict.")
+ result = []
+ for batch in batches:
+ xla_batch = {}
+ missing_keys = []
+ for key, tensor in batch.items():
+ assert type(tensor) == torch.Tensor
+ sharding_spec = None
+ if self._input_sharding:
+ if key not in self._input_sharding:
+ missing_keys.append(key)
+ continue
+ sharding_spec = self._input_sharding[key]
+
+ # xla_tensor is a list of tensors.
+ xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
+ xla_batch[key] = xla_tensor[0]
+ if len(missing_keys) != 0:
+ # Returning exception as raising in the dataloading thread doesn't surface the problem in the main thread.
+ raise KeyError(
+ f"Keys: {missing_keys} are missing from input_sharding.")
+ result.append(xla_batch)
+ else:
+ result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
+ return result
+
+ def _worker(self, dqueue, host_to_device_transfer_threads):
+ device = torch.device(dqueue.device)
+ while True:
+ batch = self._get_batch(dqueue)
+ if not batch:
+ break
+ try:
+ batch = self.send_cpu_data_to_device(batch, device)
+ except Exception as e:
+ # _worker is being run in a daemon thread, raise the error
+ # will not work. Put the error in an error queue instead.
+ self._exception_queue.put(e)
+ break
+ for data in batch:
+ dqueue.queue.put(data)
+ close_queue_count = next(dqueue.close_queue_count)
+ if close_queue_count == host_to_device_transfer_threads - 1:
+ dqueue.queue.close_write()
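+
+
+ def _example_parallel_loader():
+   # Hypothetical usage sketch (not part of the torch_xla API): wraps a CPU
+   # DataLoader and iterates over the per-device loader of the default XLA
+   # device. Assumes an XLA device is available; the toy dataset is made up.
+   import torch.utils.data
+   dataset = [(torch.randn(8), torch.tensor(0)) for _ in range(32)]
+   cpu_loader = torch.utils.data.DataLoader(dataset, batch_size=4)
+   device = xm.xla_device()
+   loader = ParallelLoader(cpu_loader, [device])
+   for data, target in loader.per_device_loader(device):
+     pass  # data and target already reside on the XLA device here
+   loader.close()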
+
+
+[docs]class MpDeviceLoader(object):
+ """Wraps an existing PyTorch DataLoader with background data upload.
+
+ This class should only be used with multi-processing data parallelism. It will wrap
+ the dataloader passed in with ParallelLoader and return the per_device_loader for the
+ current device.
+
+ Args:
+ loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+ wrapped.
+ device (`torch.device`...): The device where the data has to be sent.
+ kwargs: Named arguments for the `ParallelLoader` constructor.
+
+ Example:
+
+ >>> device = torch_xla.device()
+ >>> train_device_loader = MpDeviceLoader(train_loader, device)
+ """
+
+ def __init__(self, loader, device, **kwargs):
+ self._loader = loader
+ self._device = device
+ self._parallel_loader_kwargs = kwargs
+
+ def __iter__(self):
+ parallel_loader = ParallelLoader(self._loader, [self._device],
+ **self._parallel_loader_kwargs)
+ return parallel_loader.per_device_loader(self._device)
+
+ def __len__(self):
+ return len(self._loader)
+
+import os
+from collections import OrderedDict, defaultdict
+from dataclasses import dataclass, field
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla._internal.utils as _utils
+from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
+import torch_xla.runtime as xr
+
+import numpy as np
+import functools
+import itertools
+from typing import Tuple, Union, List, Sequence, Any, Optional, Set
+from enum import IntEnum
+
+
+[docs]class Mesh:
+ """Describe the logical XLA device topology mesh and the underlying resources.
+
+ Args:
+ device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped
+ to a `mesh_shape` array, filling the elements using C-like index order.
+
+ mesh_shape (Tuple[int, ...]): An int tuple describing the logical topology shape
+ of the device mesh, and each element describes the number of devices in
+ the corresponding axis.
+
+ axis_names (Tuple[str, ...]): A sequence of resource axis names to be assigned to the dimensions
+ of the `devices` argument. Its length should match the rank of `devices`.
+
+ Example:
+
+ >>> mesh_shape = (4, 2)
+ >>> num_devices = len(xm.get_xla_supported_devices())
+ >>> device_ids = np.array(range(num_devices))
+ >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+ >>> mesh.get_logical_mesh()
+ array([[0, 1],
+ [2, 3],
+ [4, 5],
+ [6, 7]])
+ >>> mesh.shape()
+ OrderedDict([('x', 4), ('y', 2)])
+ """
+
+ device_ids: np.ndarray
+ mesh_shape: Tuple[int, ...]
+ axis_names: Tuple[str, ...]
+
+ def __init__(self,
+ device_ids: Union[np.ndarray, List],
+ mesh_shape: Tuple[int, ...],
+ axis_names: Tuple[str, ...] = None):
+ if not isinstance(device_ids, np.ndarray):
+ device_ids = np.array(device_ids)
+ assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
+ assert axis_names is None or (len(set(axis_names)) == len(axis_names))
+ assert (len(device_ids) == np.prod(mesh_shape))
+ assert len(device_ids) == len(np.unique(device_ids))
+ self.device_ids = device_ids
+ self.mesh_shape = mesh_shape
+ self.axis_names = axis_names
+ assert all(d < self.size() for d in device_ids)
+
+ def size(self):
+ return np.prod(self.mesh_shape)
+
+ def shape(self):
+ if self.axis_names is None:
+ return OrderedDict(
+ (dim, size) for dim, size in enumerate(self.mesh_shape))
+ return OrderedDict(
+ (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
+
+ def get_logical_mesh(self):
+ return self.device_ids.reshape(self.mesh_shape)
+
+ def get_axis_name_idx(self, name: str) -> int:
+ if name not in self.axis_names:
+ return None
+ return self.axis_names.index(name)
+
+ @functools.lru_cache(maxsize=None)
+ def _get_op_sharding_args(self, partition_spec: Tuple):
+ partition_spec = _translate_named_partition_spec(self, partition_spec)
+ flat_specs = np.hstack([d for d in partition_spec])
+ specs = [d for d in flat_specs if d is not None]
+ assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+ f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+ assert len(specs) == len(np.unique(specs)), \
+ f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
+ tile_assignment = _get_tile_assignment(self, partition_spec)
+ if len(tile_assignment.shape) > len(partition_spec):
+ # Use partial replication for sharding a tensor over a higher-rank mesh
+ sharding_type = ShardingType.PARTIAL
+ else:
+ sharding_type = _get_sharding_type(partition_spec, self.size())
+ replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
+ group_assignment, replication_groups = _get_group_assignment(
+ sharding_type, tile_assignment, len(partition_spec), replicate_dims)
+
+ tile_assignment = tile_assignment.tolist()
+ sharding_type = int(sharding_type)
+ return tile_assignment, group_assignment, replication_groups, sharding_type
+
+ @functools.lru_cache(maxsize=None)
+ def get_op_sharding(self,
+ partition_spec: Tuple) -> torch_xla._XLAC.OpSharding:
+ """
+ Return the OpSharding for the given partition spec. This is an expensive
+ operation as the mesh grows, so the value is cached for reuse.
+ """
+ # For scalar tensors, it can only be replicated.
+ # We have made sure len(t.shape) == len(partition_spec)
+ # in mark_sharding API.
+ if len(partition_spec) == 0:
+ return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED)
+
+ tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
+ partition_spec)
+ return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
+ replication_groups, sharding_type)
+
+
+_GLOBAL_MESH: Mesh = None
+
+
+[docs]def set_global_mesh(mesh: Mesh):
+ """
+ Set the global mesh that can be used for the current process.
+
+ Args:
+ mesh: (Mesh) Mesh object that will be the global mesh.
+
+ Example:
+
+ >>> import torch_xla.distributed.spmd as xs
+ >>> mesh = xs.get_1d_mesh("data")
+ >>> xs.set_global_mesh(mesh)
+ """
+ global _GLOBAL_MESH
+ _GLOBAL_MESH = mesh
+
+
+[docs]def get_global_mesh() -> Optional[Mesh]:
+ """
+ Get the global mesh for the current process.
+
+ Returns:
+ mesh: (Optional[Mesh]) Mesh object if global mesh is set, otherwise return None.
+
+ Example:
+
+ >>> import torch_xla.distributed.spmd as xs
+ >>> xs.get_global_mesh()
+ """
+ global _GLOBAL_MESH
+ return _GLOBAL_MESH
+
+
+[docs]def get_1d_mesh(axis_name: Optional[str] = None) -> Mesh:
+ """
+ Helper function to return the mesh with all devices in one dimension.
+
+ Args:
+ axis_name: (Optional[str]) optional string to represent the axis name of the mesh
+
+ Returns:
+ Mesh: Mesh object
+
+ Example:
+
+ >>> # This example is assuming 1 TPU v4-8
+ >>> import torch_xla.distributed.spmd as xs
+ >>> mesh = xs.get_1d_mesh("data")
+ >>> print(mesh.mesh_shape)
+ (4,)
+ >>> print(mesh.axis_names)
+ ('data',)
+ """
+ num_devices = xr.global_runtime_device_count()
+ mesh_shape = (num_devices,)
+ device_ids = np.array(range(num_devices))
+ if axis_name == None:
+ return Mesh(device_ids, mesh_shape)
+ else:
+ return Mesh(device_ids, mesh_shape, (axis_name,))
+
+
+# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4ƒ
+[docs]class HybridMesh(Mesh):
+ """Creates a hybrid device mesh of devices connected with ICI and DCN networks.
+ The shape of logical mesh should be ordered by increasing network-intensity
+ e.g. [replica, data, model] where model has the most network communication
+ requirements.
+
+ Args:
+ ici_mesh_shape: shape of the logical mesh for inner connected devices.
+ dcn_mesh_shape: shape of logical mesh for outer connected devices.
+
+ Example:
+
+ >>> # This example is assuming 2 slices of v4-8.
+ >>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
+ >>> dcn_mesh_shape = (2, 1, 1)
+ >>> mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
+ >>> print(mesh.shape())
+ OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
+ """
+ ici_mesh_shape: Tuple[int, ...]
+ dcn_mesh_shape: Tuple[int, ...]
+
+ def __init__(self,
+ *,
+ ici_mesh_shape: Tuple[int, ...],
+ dcn_mesh_shape: Tuple[int, ...] = None,
+ axis_names: Tuple[str, ...] = None):
+ if dcn_mesh_shape is None:
+ dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
+ assert len(ici_mesh_shape) == len(dcn_mesh_shape)
+ mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
+ self.device_attributes = xr.global_runtime_device_attributes()
+ self.device_attributes.sort(
+ key=lambda attr: _utils.parse_xla_device(attr['name'])[1])
+
+ if 'slice_index' in self.device_attributes[0] and np.prod(
+ dcn_mesh_shape) == 1:
+ raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
+ if 'slice_index' not in self.device_attributes[0] and np.prod(
+ dcn_mesh_shape) > 1:
+ raise ValueError('Invalid dcn_mesh_shape for single slice mesh')
+ self.ici_mesh_shape = ici_mesh_shape
+ self.dcn_mesh_shape = dcn_mesh_shape
+ if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[
+ 0]: # multislice
+ mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape,
+ self.dcn_mesh_shape)
+ else:
+ mesh = self._create_device_mesh(self.ici_mesh_shape)
+ device_ids = mesh.flatten()
+ super().__init__(device_ids, mesh_shape, axis_names)
+
+ # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
+ def _get_physical_tpu_mesh(self, devices: Sequence[int]) -> np.ndarray:
+ r"""Rearrange TPU devices in a slice into a physical mesh.
+
+ Args:
+ devices: A list of device logical ordinals in a TPU slice.
+
+ Returns:
+ A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
+ v2 and v3, global_z is instead cores_per_chip (i.e., 2).
+ """
+ assert xm.xla_device_hw(xm.xla_device()) == 'TPU'
+ # coords is a 3-dims tuple representing the device in physical mesh
+ device_coords = [self.device_attributes[d]['coords'] for d in devices]
+ dims = tuple(d + 1 for d in max(device_coords))
+ out = np.empty(dims, dtype=int)
+ for coords, d in zip(device_coords, devices):
+ out[coords[0], coords[1], coords[2]] = d
+ return out
+
+ # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.
+ def _create_device_mesh_for_nd_torus(
+ self, physical_mesh: np.ndarray,
+ mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
+ """Assigns logical parallelism axes to physical axes of an N-D torus network.
+
+ Given logical parallelism axes with sizes in `mesh_shape` and devices in an
+ N-dimensional torus network represented by `physical_mesh`, maps each logical
+ axis to one or more physical axes. Prefer to map more-performance-sensitive
+ logical axes to larger numbers of physical axes to maximize the bandwidth
+ available to them. Also prefer to assign logical axes to multiple physical
+ axes of the same size (e.g., a 2D square) rather than multiple physical axes
+ of different sizes when possible.
+
+ Note that this routine will never split a physical axis over more than one
+ logical axis (which would reduce total usable bandwidth but may sometimes be
+ desired anyway). As a result, it will error out in cases where this is
+ necessary to produce a valid mapping.
+
+ Let's use a concrete example to explain the concepts and considerations.
+
+ As an example, suppose the logical mesh is [data, model], for data and model
+ parallelism respectively. Also suppose that data parallelism is less
+ performance sensitive than model parallelism. Consider a 3D TPU pod slice of
+ shape 4x4x16, represented by a physical mesh of shape (4, 4, 16).
+
+ A TPU pod slice has equal bandwidth along all axes with wraparound links, but
+ a 2D plane of size 4x4 may have faster XLA collective implementations than a
+ non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want
+ the more performance sensitive `model` axis to be mapped to the 4x4 XY plane.
+
+ Args:
+ physical_mesh: a np.ndarray of devices in the shape of the N-D torus
+ physical topology.
+ mesh_shape: shape of the logical mesh (size of the various logical
+ parallelism axes), with axes ordered by increasing network intensity.
+
+ Returns:
+ An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with
+ each logical parallelism axis mapped to one or more physical mesh axes.
+ The axis assignment (a list of length num_logical_axes, whose elements
+ are tuples representing physical axis indices).
+ """
+ # Remaining physical axes to be assigned to logical axes.
+ assignable_physical_mesh = list(physical_mesh.shape)
+ # Map each logical axis to a subset of physical axes.
+ assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
+ # Assign logical axes from highest network intensity to lowest.
+ # `mesh_shape` is assumed to be ordered by lowest network intensity first, so
+ # reverse it first.
+ # Assigns devices to 2D or 3D logical mesh.
+ for logical_axis_index, logical_axis_size in reversed(
+ list(enumerate(mesh_shape))):
+ for num_axes in range(3, 0, -1):
+ # map a combination of devices in physical axes to the logical axis.
+ axes = itertools.combinations(assignable_physical_mesh, num_axes)
+ indices = itertools.combinations(
+ range(len(assignable_physical_mesh)), num_axes)
+ for c_axes, c_indices in zip(axes, indices):
+ if np.prod(c_axes) == logical_axis_size:
+ assignment[logical_axis_index] = c_indices
+ # Zero the assigned physical axes.
+ assignable_physical_mesh = [
+ 0 if i in c_indices else v
+ for i, v in enumerate(assignable_physical_mesh)
+ ]
+ break
+ if assignment[logical_axis_index]:
+ # We already found an assignment from one candidate above.
+ break
+ else:
+ # If the num_axes for-loop did not break, i.e. none of the candidates worked,
+ # control reaches here via this for-else construct.
+ if logical_axis_size > 1:
+ raise NotImplementedError(
+ 'Failed to find assignment for logical_axis_index'
+ f' {logical_axis_index} of size {logical_axis_size} with remaining'
+ f' assignable mesh {assignable_physical_mesh}. The size of each'
+ ' axis in your logical mesh must be equal to the product of'
+ ' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
+ ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
+ )
+ # Flatten the assignment
+ transpose: List[int] = []
+ for x in assignment:
+ for y in x:
+ transpose.append(int(y))
+ return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment
+
+ def _create_device_mesh(self,
+ mesh_shape: Sequence[int],
+ devices: Sequence[Any] = None) -> Sequence[int]:
+ """Creates a performant device mesh.
+
+ Args:
+ mesh_shape: shape of logical mesh, ordered by increasing network-intensity
+ e.g. [replica, data, mdl] where mdl has the most network communication
+ requirements.
+ devices: optionally, the devices to construct a mesh for.
+
+ Returns:
+ A np.ndarray of devices with mesh_shape as its shape.
+ """
+
+ if devices is None:
+ devices = np.arange(xr.global_runtime_device_count())
+ if np.prod(mesh_shape) != len(devices):
+ raise ValueError(
+ f'Number of devices {len(devices)} must equal the product '
+ f'of mesh_shape {mesh_shape}')
+ physical_mesh = self._get_physical_tpu_mesh(devices)
+ device_mesh, assignment = self._create_device_mesh_for_nd_torus(
+ physical_mesh, mesh_shape)
+ return device_mesh
+
+ # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
+ def _create_hybrid_device_mesh(
+ self, ici_mesh_shape: Sequence[int],
+ dcn_mesh_shape: Sequence[int]) -> Sequence[int]:
+ """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
+
+ Args:
+ ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
+ by increasing network intensity, e.g. [replica, data, mdl] where mdl has
+ the most network communication requirements.
+ dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
+ in the same order as mesh_shape.
+
+ Returns:
+ A np.ndarray of device logical ordinal with ici_mesh_shape * dcn_mesh_shape as its shape
+ that can be fed into HybridMesh for hybrid parallelism.
+ """
+ granule_dict = defaultdict(list)
+ for d, dev in enumerate(self.device_attributes):
+ granule_dict[dev['slice_index']].append(d)
+ # sorts devices based on slice_index.
+ granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
+ if np.prod(dcn_mesh_shape) != len(granules):
+ raise ValueError(
+ f'Number of slices {len(granules)} must equal the product of '
+ f'dcn_mesh_shape {dcn_mesh_shape}')
+ # creates a separate internal mesh for each slice.
+ per_granule_meshes = [
+ self._create_device_mesh(ici_mesh_shape, granule)
+ for granule in granules
+ ]
+ granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
+ blocks = np.vectorize(
+ lambda i: per_granule_meshes[i], otypes=[object])(
+ granule_mesh)
+ device_mesh = np.block(blocks.tolist())
+ return device_mesh
+
+
+class ShardingType(IntEnum):
+ # ShardingType enum ID maps to OpSharding.Type (https://shorturl.at/pvAJX)
+ REPLICATED = 0
+ MAXIMAL = 1
+ TUPLE = 2
+ TILED = 3
+ MANUAL = 4
+ PARTIAL = 5
+ UNKNOWN = 6 # implicit replication. TODO(yeounoh) wait for auto-sharding support
+
+
+def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
+ num_devices: int) -> ShardingType:
+ sharding_type = ShardingType.TILED
+ if num_devices == 1:
+ sharding_type = ShardingType.MAXIMAL
+ elif all(d is None for d in partition_spec):
+ sharding_type = ShardingType.REPLICATED
+ elif any(d is None for d in partition_spec):
+ sharding_type = ShardingType.PARTIAL
+ return sharding_type
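+
+
+ def _example_sharding_types():
+   # Hypothetical usage sketch (not part of the torch_xla API): shows how a
+   # partition spec maps to a ShardingType for a rank-2 tensor.
+   assert _get_sharding_type((0, 1), num_devices=8) == ShardingType.TILED
+   assert _get_sharding_type((None, None), num_devices=8) == ShardingType.REPLICATED
+   assert _get_sharding_type((0, None), num_devices=8) == ShardingType.PARTIAL
+   assert _get_sharding_type((0, 1), num_devices=1) == ShardingType.MAXIMAL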
+
+
+def _get_tile_assignment(
+ mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int,
+ None]]) -> np.ndarray:
+ """
+ Permute the given mesh to create the tile assignment based on the partition
+ spec. Returns the tiling assignment as a numpy ndarray.
+
+ If the input partition_spec combines multiple logical mesh axes over a single
+ tensor axis, the resulting tiling assignment will combine the specified axes
+ into a single axis.
+ """
+ # Flatten the partition spec and ensure that it is fully specified over the
+ # mesh for permutation.
+ tiled_dims = [x for x in partition_spec if x is not None]
+ permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
+ missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
+ tile_assignment = mesh.get_logical_mesh().transpose(permutation +
+ missing_axes)
+
+ # For any tuples in the partition_spec, the grouped axes will be adjacent
+ # after the permutation. Combine these dimensions into a single axis.
+ for i, spec in enumerate(tiled_dims):
+ if isinstance(spec, tuple):
+ shape = tile_assignment.shape
+ tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) +
+ shape[i + len(spec):])
+
+ return tile_assignment
+
+
+# Produce group assignment for partial replication. Partial replication tiles
+# groups (a.k.a. sub-groups) where the shards are fully replicated within each
+# sub-group. `replication_groups` is a list of groups as lists, where each group
+# contains the participating device IDs. `group_assignment` describes the group
+# placement and the overall mesh, where each element is the group ID.
+# The tile_assignment should be the result of `_get_tile_assignment` so that all
+# tiled dimensions are in the first axes and replicated dimensions are in the
+# remaining axes.
+def _get_group_assignment(sharding_type: ShardingType,
+ tile_assignment: np.ndarray, tensor_rank: int,
+ replicate_dims: Set[int]) -> Tuple[List, List]:
+ group_assignment = list()
+ replication_groups = list()
+ if sharding_type is ShardingType.PARTIAL:
+ # Shard across groups and replicate within subgroups; replicated dims
+ # will be used to group replication devices.
+ tile_shape = tile_assignment.shape
+ # When creating the tile assignment, the mesh is permuted so that the first
+ # few axes are used for tiling.
+ tile_dims = range(tensor_rank - len(replicate_dims))
+ group_list = [tile_assignment]
+ for d in tile_dims:
+ _group_list = list()
+ for group_members in group_list:
+ _group_list += np.split(group_members, tile_shape[d], d)
+ group_list = _group_list
+ replication_groups = [group.flatten().tolist() for group in group_list]
+
+ mesh_axis = itertools.count()
+ group_tile_shape = [
+ 1 if d in replicate_dims else tile_shape[next(mesh_axis)]
+ for d in range(tensor_rank)
+ ]
+ group_assignment = np.arange(len(replication_groups)).reshape(
+ tuple(group_tile_shape)).tolist()
+ return group_assignment, replication_groups
+
+
+def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
+ _partition_spec = list()
+ for p in partition_spec:
+ if type(p) is tuple:
+ assert not any(type(x) is tuple
+ for x in p), 'Partition spec cannot contain nested tuples'
+ _partition_spec.append(_translate_named_partition_spec(mesh, p))
+ elif (p is None) or (type(p) is int):
+ _partition_spec.append(p)
+ elif type(p) is str:
+ idx = mesh.get_axis_name_idx(p)
+ if idx is None:
+ raise ValueError(f"Axis name {p} is not defined in the given mesh")
+ _partition_spec.append(idx)
+ else:
+ raise ValueError(
+ f"Spec type {type(p)} is not supported in partition spec")
+ return tuple(_partition_spec)
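+
+
+ def _example_translate_named_partition_spec():
+   # Hypothetical usage sketch (not part of the torch_xla API): named axes in a
+   # partition spec are translated into mesh-axis indices. The (4, 2) mesh only
+   # needs 8 device IDs, so this runs without any accelerator attached.
+   mesh = Mesh(np.arange(8), (4, 2), ('data', 'model'))
+   assert _translate_named_partition_spec(mesh, ('data', None)) == (0, None)
+   assert _translate_named_partition_spec(mesh, (('data', 'model'),)) == ((0, 1),)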
+
+
+def _mark_manual_sharding(
+ t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
+ """
+ This API is meant to be paired with the upcoming pause_spmd&resume_spmd APIs.
+ Don't use it alone.
+ """
+ manual_sharding = torch_xla._XLAC.OpSharding([], [], [], ShardingType.MANUAL)
+ torch_xla._XLAC._mark_manual_sharding(
+ unwrap_sharded_tensor(t), manual_sharding)
+ return wrap_as_sharded_tensor(t)
+
+
+def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+ partition_spec: Tuple[Union[Tuple, int, str, None]],
+ *,
+ mesh: Mesh = None) -> XLAShardedTensor:
+ """
+ This API enables manual sharding for the given tensor. Manual sharding disables SPMD sharding propagation and auto
+ partitioning for the given tensor and all subsequent tensors produced by ops that use the given tensor as
+ input, and therefore allows the user to manually call collectives for the tensor and the subsequent tensors. It
+ requires the user to provide the partition spec to shard the tensor before enabling manual sharding. Note that
+ the leaf tensors need to be passed to disable_manual_sharding before ending the graph.
+ """
+ mesh = get_global_mesh() if mesh is None else mesh
+ t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
+ t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t))
+ return wrap_as_sharded_tensor(t)
+
+
+def disable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+ partition_spec: Tuple[Union[Tuple, int, str, None]],
+ full_shape: torch.Size,
+ *,
+ mesh: Mesh = None) -> XLAShardedTensor:
+ """
+ This API disables manual sharding for the given tensor. The partition_spec and full_shape are used to construct the
+ output tensor as if the input tensor had not been manually sharded.
+ """
+ mesh = get_global_mesh() if mesh is None else mesh
+ t = _mark_manual_sharding(unwrap_sharded_tensor(t))
+ t = torch_xla._XLAC._spmd_shard_to_full_shape(
+ unwrap_sharded_tensor(t), mesh.get_op_sharding(partition_spec),
+ full_shape, t.dtype)
+ return wrap_as_sharded_tensor(t)
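+
+
+ def _example_manual_sharding_roundtrip():
+   # Hypothetical usage sketch (not part of the torch_xla API): shards a tensor,
+   # switches it to manual sharding, and restores the full shape afterwards.
+   # Assumes SPMD mode is active and a global mesh has been set with
+   # set_global_mesh(); the (16, 8) shape is arbitrary.
+   mesh = get_global_mesh()
+   full = torch.randn(16, 8).to(xm.xla_device())
+   shard = enable_manual_sharding(full, (0, None), mesh=mesh)
+   # ... user-managed collectives would operate on `shard` here ...
+   return disable_manual_sharding(
+       shard, (0, None), torch.Size((16, 8)), mesh=mesh)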
+
+
+[docs]def mark_sharding(
+ t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
+ partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
+ """
+ Annotates the tensor provided with XLA partition spec. Internally,
+ it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
+
+ Args:
+ t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.
+
+ mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
+
+ partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or
+ `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str.
+ This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
+ When a tuple is specified, the corresponding input tensor axis will be sharded along all
+ logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
+ will impact the resulting sharding.
+
+ Example:
+
+ >>> import torch_xla.runtime as xr
+ >>> import torch_xla.distributed.spmd as xs
+ >>> mesh_shape = (4, 2)
+ >>> num_devices = xr.global_runtime_device_count()
+ >>> device_ids = np.array(range(num_devices))
+ >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+ >>> input = torch.randn(8, 32).to(xm.xla_device())
+ >>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
+ >>> linear = nn.Linear(32, 10).to(xm.xla_device())
+ >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
+ """
+ num_devices = xr.global_runtime_device_count()
+ assert num_devices > 0, "This requires XLA supported device(s)."
+ assert mesh.size() == num_devices, \
+ f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
+ # We only allow fully specified `partition_spec` to be applicable, as opposed
+ # to filling in the unspecified replicated dims. Fully specified `partition_spec`
+ # should be of the same rank as `t`. This is to support partial replication
+ # where the group assignment may vary with different input ranks.
+ assert len(t.shape) == len(partition_spec), \
+ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
+
+ op_sharding = mesh.get_op_sharding(partition_spec)
+ annotate_func = torch_xla._XLAC._xla_mark_sharding
+ annotate_func(unwrap_sharded_tensor(t), op_sharding)
+ return wrap_as_sharded_tensor(t)
+
+
+[docs]def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
+ """
+ Clear the sharding annotation from the input tensor and return the tensor. This
+ is an in-place operation but will also return the same torch.Tensor back.
+
+ Args:
+ t (Union[torch.Tensor, XLAShardedTensor]): Tensor that we want to clear the sharding
+
+ Return:
+ t (torch.Tensor): the tensor without sharding annotation.
+
+ Example:
+
+ >>> import torch_xla.distributed.spmd as xs
+ >>> torch_xla.runtime.use_spmd()
+ >>> t1 = torch.randn(8,8).to(torch_xla.device())
+ >>> mesh = xs.get_1d_mesh()
+ >>> xs.mark_sharding(t1, mesh, (0, None))
+ >>> xs.clear_sharding(t1)
+ """
+ torch_xla._XLAC._xla_clear_sharding(unwrap_sharded_tensor(t))
+ if isinstance(t, XLAShardedTensor):
+ return t.global_tensor
+ return t
+
+
+def wrap_as_sharded_tensor(
+ t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
+ if not isinstance(t, XLAShardedTensor):
+ return XLAShardedTensor(t)
+ return t
+
+
+def unwrap_sharded_tensor(
+ t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
+ if isinstance(t, XLAShardedTensor):
+ return t.global_tensor
+ return t
+
+
+def wrap_if_sharded(x: Any) -> Any:
+ """
+ If the input is a sharded tensor, return an XLAShardedTensor wrapping it.
+ Otherwise, returns the input.
+ """
+ if (isinstance(x, torch.Tensor) and not isinstance(x, XLAShardedTensor) and
+ x.device.type == 'xla' and
+ torch_xla._XLAC._get_xla_sharding_type(x) is not None):
+ return XLAShardedTensor(x)
+ return x
+
+
+@dataclass
+class ShardingSpec:
+ mesh: Mesh
+ partition_spec: Tuple[Union[int, None]]
+ minibatch: Optional[bool] = False
+
+ # Derived fields
+ _tile_assignment: List[int] = field(init=False)
+ _group_assignment: List[int] = field(init=False)
+ _replication_groups: List[int] = field(init=False)
+ _sharding_type: ShardingType = field(init=False)
+
+ def __post_init__(self):
+ mesh = self.mesh
+ partition_spec = _translate_named_partition_spec(mesh, self.partition_spec)
+ tile_assignment = _get_tile_assignment(mesh, partition_spec)
+ self._tile_assignment = tile_assignment.tolist()
+ self._sharding_type = _get_sharding_type(partition_spec,
+ xr.global_runtime_device_count())
+ replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
+ self._group_assignment, self._replication_groups = _get_group_assignment(
+ self._sharding_type, tile_assignment, len(partition_spec),
+ replicate_dims)
+
+ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
+ """
+ Create an XlaShardingSpec for the given tensor. If the tensor is
+ incompatible with the ShardingSpec, returns None.
+ """
+ if not self.can_apply(t):
+ return None
+ return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
+ self._group_assignment,
+ self._replication_groups,
+ int(self._sharding_type),
+ self.minibatch)
+
+ def can_apply(self, t: torch.Tensor) -> bool:
+ """
+ Test whether the ShardingSpec is compatible with the given torch.Tensor.
+ """
+ return len(t.shape) == len(self.partition_spec)
+
+ def apply(self, t: torch.Tensor):
+ # TODO(yeounoh) use virtual device interface when available.
+ assert (t.device == xm.xla_device())
+ mark_sharding(t, self.mesh, self.partition_spec)
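+
+
+ def _example_sharding_spec():
+   # Hypothetical usage sketch (not part of the torch_xla API): builds a
+   # ShardingSpec that shards the batch dimension of rank-2 tensors over a 1D
+   # device mesh. Assumes SPMD mode and an initialized XLA runtime.
+   mesh = get_1d_mesh('data')
+   spec = ShardingSpec(mesh, ('data', None))
+   batch = torch.randn(16, 128).to(xm.xla_device())
+   assert spec.can_apply(batch)  # tensor rank matches the partition spec
+   spec.apply(batch)  # annotates the tensor in place via mark_sharding
+   return batch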
+
+
+class XLAPatchedLinear(torch.autograd.Function):
+ """
+ A patched version of `torch.nn.functional.linear` that uses einsum instead
+ of torch.matmul, which would flatten the tensors to 2D and collapse the sharded
+ dimensions. The default torch.matmul behavior makes it very hard for the XLA compiler
+ to propagate the sharding annotation.
+
+ TODO (alanwaketan): Let's patch it on the dispatcher level.
+ """
+
+ @staticmethod
+ def forward(ctx, input, weight, bias=None):
+ # bias is an optional argument
+ ctx.save_for_backward(input, weight, bias)
+ with torch.no_grad():
+ product = torch.einsum('...n,mn->...m', input, weight)
+ if bias is None:
+ return product
+ return product + bias
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight, bias = ctx.saved_tensors
+ grad_input = grad_weight = grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = torch.einsum('...m,mn->...n', grad_output, weight)
+ if ctx.needs_input_grad[1]:
+ grad_weight = torch.einsum('...m,...n->mn', grad_output, input)
+ if bias is not None and ctx.needs_input_grad[2]:
+ grad_bias = torch.einsum('...m->m', grad_output)
+
+ return grad_input, grad_weight, grad_bias
+
+
+def xla_patched_nn_linear_forward(m, input):
+ return XLAPatchedLinear.apply(input, m.weight, m.bias)
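+
+
+ def _example_patch_linear(module: torch.nn.Linear):
+   # Hypothetical usage sketch (not part of the torch_xla API): replaces one
+   # Linear module's forward with the einsum-based patched version so sharding
+   # annotations on its weight can propagate through the matmul.
+   module.forward = functools.partial(xla_patched_nn_linear_forward, module)
+   return module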
+
+
+def apply_backward_optimization_barrier(m: torch.nn.Module):
+ """
+ Register a full backward hook that applies an optimization barrier to the given module.
+ This will prevent the XLA compiler from fusing the module's backward pass with others.
+ It's useful to prevent gigantic buffers from being allocated to synchronize the gradients.
+ """
+
+ def optimization_barrier(module, grad_input, grad_output):
+ from torch_xla.utils.checkpoint import CheckpointFunction
+ gradients = []
+ for param in module.parameters():
+ if param.grad is not None:
+ gradients.append(param.grad)
+ xm.optimization_barrier_(
+ CheckpointFunction._extract_tensors_from_list(gradients +
+ list(grad_input)))
+
+ m.register_full_backward_hook(optimization_barrier)
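+
+
+ def _example_apply_barrier_per_layer(model: torch.nn.Module):
+   # Hypothetical usage sketch (not part of the torch_xla API): registers the
+   # backward optimization barrier on every direct child module so each layer's
+   # backward pass stays in its own XLA computation region.
+   for child in model.children():
+     apply_backward_optimization_barrier(child)
+   return model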
+
+import torch.multiprocessing
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt
+
+
+[docs]def spawn(fn,
+ args=(),
+ nprocs=None,
+ join=True,
+ daemon=False,
+ start_method='spawn'):
+ """Enables multi processing based replication.
+
+ Args:
+ fn (callable): The function to be called for each device which takes part of
+ the replication. The function will be called with a first argument being
+ the global index of the process within the replication, followed by the
+ arguments passed in `args`.
+ args (tuple): The arguments for `fn`.
+ Default: Empty tuple
+ nprocs (int): The number of processes/devices for the replication. At the
+ moment, if specified, can be either 1 or the maximum number of devices.
+ join (bool): Whether the call should block waiting for the completion of the
+ processes which have been spawned.
+ Default: True
+ daemon (bool): Whether the processes being spawned should have the `daemon`
+ flag set (see Python multi-processing API).
+ Default: False
+ start_method (string): The Python `multiprocessing` process creation method.
+ Default: `spawn`
+
+ Returns:
+ The same object returned by the `torch.multiprocessing.spawn` API. If
+ `nprocs` is 1 the `fn` function will be called directly, and the API will
+ return None.
+ """
+ return pjrt.spawn(fn, nprocs, start_method, args)
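+
+
+ def _example_mp_fn(index):
+   # Hypothetical per-process entry point (not part of the torch_xla API):
+   # `index` is the global index of the process within the replication.
+   print('process', index, 'of', xr.world_size())
+
+
+ def _example_spawn():
+   # Hypothetical usage sketch (not part of the torch_xla API): launches
+   # _example_mp_fn on every available device. Call it from inside an
+   # `if __name__ == "__main__":` guard in a user script.
+   spawn(_example_mp_fn, args=())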
+
+
+class MpModelWrapper(object):
+ """Wraps a model to minimize host memory usage when `fork` method is used.
+
+ This class should be used together with the `spawn(..., start_method='fork')`
+ API to minimize the use of host memory.
+ Instead of creating models on each multiprocessing process, hence replicating
+ the model's initial host memory, the model is created once at global scope,
+ and then moved into each device inside the `spawn()` target function.
+ Example::
+
+ WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())
+
+ def _mp_fn(index, ...):
+ device = xm.xla_device()
+ model = WRAPPED_MODEL.to(device)
+ ...
+
+ torch_xla.launch(_mp_fn, ..., start_method='fork')
+
+ This method has two advantages. First it uses only one copy of the memory
+ pages to host the original model weights, and second it serializes the move
+ of the wrapped model into each device, by lowering the load onto the system
+ memory during the process.
+ """
+
+ def __init__(self, model):
+ """Creates a new `MpModelWrapper` object.
+
+ Args:
+ model (torch.nn.Module): The model to be wrapped. Should be on PyTorch CPU
+ device (which is the default when creating new models).
+ """
+ self._model = model
+ self._lock = torch.multiprocessing.Lock()
+
+ def to(self, device):
+ """Retrieves the model moved onto the specified device.
+
+ Args:
+ device (torch.device): The device where the model should be moved onto.
+ Returns:
+ The model on the specified device.
+ """
+ with self._lock:
+ self._model.to(device)
+ return self._model
+
+
+class MpSerialExecutor(object):
+ """Utility to run a function in a serialized fashion among multi-core processes.
+
+ Example::
+
+ # At global scope.
+ SERIAL_EXEC = xmp.MpSerialExecutor()
+
+ def load_dataset(path):
+ return maybe_download_and_load(path)
+
+ def _mp_fn(index, ...):
+ # Avoid all cores downloading the same data with the serial executor.
+ dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data'))
+ ...
+
+ torch_xla.launch(_mp_fn, ...)
+ """
+
+ def __init__(self):
+ self._lock = torch.multiprocessing.Lock()
+
+ def run(self, fn):
+ """Runs the provided function serialized WRT each per-core process.
+
+ Args:
+ fn (callable): The function to run in a serialized fashion.
+ Returns:
+ The `fn` return value.
+ """
+ with self._lock:
+ return fn()
+
+import functools
+from contextlib import contextmanager
+
+import torch_xla
+import logging
+
+
+[docs]def eager_mode(enable: bool):
+ """Configure torch_xla's default executation mode.
+
+ Under eager mode, only functions wrapped in `torch_xla.compile` will be
+ traced and compiled. Other torch ops will be executed eagerly.
+ """
+ torch_xla._XLAC._set_use_eager_mode(enable)
+
+
+def is_eager_mode() -> bool:
+ """Return True if torch_xla is currently under eager mode
+ """
+ return torch_xla._XLAC._get_use_eager_mode()
+
+
+@contextmanager
+def eager_mode_context(enable: bool):
+ """Context manager to enable/disable the eager mode.
+ """
+ saved_eager_mode = is_eager_mode()
+ eager_mode(enable)
+ try:
+ yield saved_eager_mode
+ finally:
+ eager_mode(saved_eager_mode)
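+
+
+ def _example_eager_mode_context():
+   # Hypothetical usage sketch (not part of the torch_xla API): temporarily
+   # enables eager mode so the ops inside the block execute immediately instead
+   # of being traced into a lazy graph. Assumes an XLA device is available.
+   import torch
+   device = torch_xla.device()
+   with eager_mode_context(True):
+     t = torch.ones(2, 2, device=device) * 3  # executed eagerly
+   return t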
+
+
+def compile(func):
+ # can't use the deprecated wrapper at import time due to a circular dependency
+ logging.warning(
+ 'torch_xla.experimental.compile is deprecated. Use torch_xla.compile instead.'
+ )
+ return torch_xla.compile(func)
+
+import functools
+import logging
+import os
+import warnings
+from typing import Dict, List, Optional, TypeVar
+
+import torch
+import torch.cuda
+import torch_xla
+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.core.xla_model as xm
+import torch_xla.utils.utils as xu
+import torch_xla._internal.utils as _utils
+import torch_xla._internal.tpu as tpu
+from torch_xla.experimental import plugins
+from torch_xla import runtime
+
+R = TypeVar('R')
+FN = TypeVar('FN')
+
+ # Note [Dynamo WORLD_SIZE and ORDINAL]
+ # Below is a workaround to cache the ordinal and world_size such that
+# Dynamo won't do graph breaks when runtime.world_size() and runtime.global_ordinal() are called.
+_WORLD_SIZE = None
+_ORDINAL = None
+
+
+def _init_world_size_ordinal():
+ global _WORLD_SIZE, _ORDINAL
+
+ # Dynamo doesn't support XRT or multithreaded runtime. See Note [V3-8 Threading]
+ if runtime.addressable_device_count() > 1:
+ return
+
+ if _WORLD_SIZE is None:
+ _WORLD_SIZE = runtime.world_size()
+ _ORDINAL = runtime.global_ordinal()
+
+
+def set_device_type(pjrt_device: str) -> None:
+ """Sets the current PjRt device type.
+
+ Must be run before using any XLA devices.
+
+ Args:
+ pjrt_device: 'TPU' or 'CPU'
+ """
+ if torch_xla._XLAC._xla_runtime_is_initialized() and os.environ.get(
+ xenv.PJRT_DEVICE) != pjrt_device:
+ raise RuntimeError(
+ "Can't change device type after XLA runtime is initialized")
+
+ os.environ[xenv.PJRT_DEVICE] = pjrt_device
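+
+
+ def _example_force_cpu_runtime():
+   # Hypothetical usage sketch (not part of the torch_xla API): pins the PjRt
+   # device type to CPU before any XLA device is touched, e.g. for unit tests.
+   set_device_type('CPU')
+   return device_type()  # -> 'CPU'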
+
+
+def _maybe_select_default_device():
+ if xu.getenv_as(xenv.PJRT_SELECT_DEFAULT_DEVICE, str,
+ '1') == '0' or xenv.PJRT_DEVICE in os.environ:
+ return
+
+ # Check for libtpu _and_ the TPU device
+ if torch_xla._found_libtpu and tpu.num_available_chips() > 0:
+ logging.warning('libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.')
+ os.environ[xenv.PJRT_DEVICE] = 'TPU'
+ elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0:
+ logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=CUDA')
+ os.environ[xenv.PJRT_DEVICE] = 'CUDA'
+ elif torch.cuda.is_available() and torch.cuda.device_count() > 0:
+ num_devices_str = str(torch.cuda.device_count())
+ logging.warning(
+ 'Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES='
+ + num_devices_str)
+ os.environ[xenv.PJRT_DEVICE] = 'CUDA'
+ os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str
+ elif torch_xla._found_libneuronxla:
+ logging.warning('Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.')
+ os.environ[xenv.PJRT_DEVICE] = 'NEURON'
+ else:
+ logging.warning('Defaulting to PJRT_DEVICE=CPU')
+ os.environ[xenv.PJRT_DEVICE] = 'CPU'
+
+
+[docs]def device_type() -> Optional[str]:
+ """Returns the current PjRt device type.
+
+ Selects a default device if none has been configured
+
+ Returns:
+ A string representation of the device.
+ """
+ pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
+ return pjrt_device.split('_')[0] if pjrt_device else pjrt_device
+
+
+def is_bf16_supported():
+ """Returns whether torch.bfloat16 is supported on this environment.
+ """
+ try:
+ torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device())
+ return True
+ except Exception as e:
+ return False
+
+
+def xla_device(n: Optional[int] = None,
+ devkind: Optional[str] = None) -> torch.device:
+ """Returns an XLA device.
+
+ Args:
+ n: Index of XLA device within visible devices. If not set, use local
+ ordinal (default 0) to select an addressable device.
+ devkind: Type of device to return. Should match `device_type()`.
+
+ Returns:
+ A `torch.device` representing an XLA device.
+ """
+ if n is None:
+ return torch.device(torch_xla._XLAC._xla_get_default_device())
+
+ devices = xm.get_xla_supported_devices(devkind=devkind)
+ if n > len(devices):
+ raise IndexError('Device index {} out of range in {}'.format(n, devices))
+
+ device = devices[n]
+ torch_xla._XLAC._xla_set_default_device(device)
+ return torch.device(device)
+
+
+[docs]def local_process_count() -> int:
+ """Returns the number of processes running on this host."""
+ return xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_COUNT, int, defval=1)
+
+
+[docs]def global_device_count() -> int:
+ """Returns the total number of devices across all processes/hosts."""
+ return len(torch_xla._XLAC._xla_get_all_devices())
+
+
+[docs]def world_size() -> int:
+ """Returns the total number of processes participating in the job."""
+ global _WORLD_SIZE
+ if _WORLD_SIZE is not None:
+ return _WORLD_SIZE
+ if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
+ _WORLD_SIZE = 1
+ else:
+ _WORLD_SIZE = global_device_count()
+ return _WORLD_SIZE
+
+
+[docs]def local_device_count() -> int:
+ """Returns the total number of devices on this host.
+
+ Assumes each process has the same number of addressable devices.
+ """
+ return local_process_count() * addressable_device_count()
+
+
+[docs]def addressable_device_count() -> int:
+ """Returns the number of devices visible to this process."""
+ return torch_xla._XLAC._xla_num_devices()
+
+
+[docs]def global_ordinal() -> int:
+ """Returns global ordinal of this thread within all processes.
+
+ Global ordinal is in range [0, global_device_count). Global ordinals are not
+ guaranteed to have any predictable relationship to the TPU worker ID nor are
+ they guaranteed to be contiguous on each host."""
+ global _ORDINAL
+ if _ORDINAL is not None:
+ return _ORDINAL
+ return torch_xla._XLAC._xla_get_default_device_ordinal()
+
+
+[docs]def local_ordinal() -> int:
+ """Returns local ordinal of this thread within this host.
+
+ Local ordinal is in range [0, local_device_count)."""
+ local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
+ devices_per_process = addressable_device_count()
+ return local_rank * devices_per_process + xla_device().index
+
+
+def process_index() -> int:
+ return torch_xla._XLAC._xla_get_process_index()
+
+
+def process_count() -> int:
+ return torch_xla._XLAC._xla_get_num_processes()
+
+
+def host_index() -> int:
+ if plugins.using_dynamic_plugins():
+ return plugins.default().host_index()
+ elif device_type() == 'TPU':
+ return tpu.worker_id()
+
+ # TODO: Update this when we support multi-host GPU
+ return 0
+
+
+ # APIs below are used to query physical device attributes.
+def runtime_device_attributes(device: str) -> Dict[str, object]:
+ return torch_xla._XLAC._xla_get_device_attributes(device)
+
+
+def global_runtime_device_attributes() -> List[Dict[str, object]]:
+ return torch_xla._XLAC._xla_get_all_device_attributes()
+
+
+[docs]@functools.lru_cache()
+def global_runtime_device_count() -> int:
+ """Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD."""
+ return len(torch_xla._XLAC._xla_get_all_runtime_devices())
+
+
+def addressable_runtime_device_count() -> int:
+ """Returns the number of devices visible to this process."""
+ return torch_xla._XLAC._xla_num_runtime_devices()
+
+
+# TODO(yeounoh) introduce SPMD configuration.
+[docs]def use_spmd(auto: Optional[bool] = False):
+ """API to enable SPMD mode. This is a recommended way to enable SPMD.
+
+ This forces SPMD mode if some tensors are already initialized on non-SPMD
+ devices. This means that those tensors would be replicated across the devices.
+
+ Args:
+ auto (bool): Whether to enable the auto-sharding. Read
+ https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding
+ for more detail
+ """
+ if os.environ.get("XLA_USE_SPMD") is not None:
+ warnings.warn("XLA_USE_SPMD is being deprecated. "
+ "Use torch_xla.runtime.use_spmd() "
+ "without setting XLA_USE_SPMD env-var.")
+
+ if torch_xla._XLAC._xla_get_spmd_config_is_locked(
+ ) and not xu.check_env_flag("XLA_USE_SPMD"):
+ warnings.warn(
+ "Replicating tensors already initialized on non-virtual XLA device for SPMD "
+ "to force SPMD mode. This is one-time overhead to setup, and to minimize such, "
+ "please set SPMD mode before initializting tensors "
+ "(i.e., call use_spmd() in the beginning of the program).")
+ torch_xla._XLAC._xla_force_spmd_device()
+ xm.wait_device_ops()
+
+ # TODO(yeounoh) we can drop envvar in the future
+ os.environ["XLA_USE_SPMD"] = "1"
+ if auto:
+ torch_xla._XLAC._xla_set_auto_sharding()
+ os.environ["XLA_AUTO_SPMD"] = "1"
+
+ if device_type() == 'NEURON':
+ # In case of Neuron, reset the initialization environment to accommodate SPMD.
+ try:
+ from torch_neuronx.initialization import initialize
+
+ initialize()
+ except ImportError:
+ pass
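+
+
+ def _example_enable_spmd():
+   # Hypothetical usage sketch (not part of the torch_xla API): enables SPMD
+   # mode at program start, before any tensor is created, and verifies the flag.
+   use_spmd()
+   assert is_spmd()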
+
+
+[docs]def is_spmd():
+ """Returns if SPMD is set for execution."""
+ # TODO(yeounoh) replace this when we fully deprecate the flag.
+ return xu.check_env_flag('XLA_USE_SPMD')
+
+
+[docs]def get_master_ip() -> str:
+ """Retrieve the master worker IP for the runtime. This calls into
+ backend-specific discovery APIs.
+
+ Returns:
+ master worker's IP address as a string."""
+ if device_type() == 'TPU':
+ return tpu.discover_master_worker_ip()
+ raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
+
+
+[docs]def initialize_cache(path: str, readonly: bool = False):
+ """Initializes the persistent compilation cache. This API must be called
+ before any computations have been performed.
+
+ Args:
+ path (str): The path at which to store the persistent cache.
+ readonly (bool): Whether or not this worker should have write access to the cache.
+ """
+ assert not torch_xla._XLAC._xla_computation_cache_is_initialized(
+ ), "Computation cache has already been initialized"
+
+ # TODO(jonbolin): Consider moving away from environment variables to control
+ # the cache.
+ os.environ['XLA_PERSISTENT_CACHE_PATH'] = path
+ os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1' if readonly else '0'
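+
+
+ def _example_persistent_cache():
+   # Hypothetical usage sketch (not part of the torch_xla API): points the
+   # persistent compilation cache at a local directory. Must run before any XLA
+   # computation; the path below is only an illustrative placeholder.
+   initialize_cache('/tmp/xla_compile_cache', readonly=False)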
+
+import sys
+import collections
+import contextlib
+import functools
+import uuid
+from typing import Any, Callable, List, Optional, Tuple
+import weakref
+
+import torch
+import torch.distributed as dist
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu
+
+
+[docs]def device(index: int = None) -> torch.device:
+ """Returns a given instance of an XLA device.
+
+ If SPMD is enabled, returns a virtual device that wraps all devices available
+ to this process.
+
+ Args:
+ index: index of the XLA device to be returned. Corresponds to index in
+ `torch_xla.devices()`.
+
+ Returns:
+ An XLA `torch.device`.
+ """
+
+ return xm.xla_device(index)
+
+
+[docs]def devices() -> List[torch.device]:
+ """Returns all devices available in the current process.
+
+ Returns:
+ A list of XLA `torch.devices`.
+ """
+
+ return [torch.device(d) for d in xm.get_xla_supported_devices()]
+
+
+def real_devices() -> List[str]:
+ """Returns local XLA device types and indices.
+
+ Returns:
+ A list of strings representing the XLA devices available in the current
+ process, e.g. `['TPU:0', 'TPU:1', ...]`.
+ """
+
+ return torch_xla._XLAC._xla_real_devices()
+
+
+[docs]def device_count() -> int:
+ """Returns number of addressable devices in the current process."""
+ return len(real_devices())
+
+
+[docs]def sync(wait: bool = False):
+ """Launches all pending graph operations.
+
+ Args:
+ wait (bool): whether to block the current process until the execution has finished.
+
+ """
+ torch_xla._XLAC._xla_step_marker(
+ torch_xla._XLAC._xla_get_default_device(),
+ [],
+ wait=wait,
+ )
+ devctx = xm._run_step_closures()
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
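+
+
+ def _example_sync():
+   # Hypothetical usage sketch (not part of the torch_xla API): builds a small
+   # lazy computation and forces its dispatch, blocking until execution is done.
+   t = torch.randn(4, 4, device=device())
+   t = t @ t
+   sync(wait=True)
+   return t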
+
+
+def step():
+ """Wraps code that should be dispatched to the runtime.
+
+ Experimental: `xla.step` is still a work in progress. Some code that currently
+ works with `xla.step` but does not follow best practices will become errors in
+ future releases. See https://github.com/pytorch/xla/issues/6751 for context.
+ """
+ return compile()
+
+
+ # Keeps track of the alive functions. This allows us to remove session entries in the
+# C++ side for functions that are no longer alive.
+_compiled_id_to_functions_ref = weakref.WeakValueDictionary()
+
+
+[docs]def compile(
+ f: Optional[Callable] = None,
+ full_graph: Optional[bool] = False,
+ name: Optional[str] = None,
+ num_different_graphs_allowed: Optional[int] = None,
+):
+ """
+ Optimizes the given model/function using torch_xla's LazyTensor tracing mode.
+ PyTorch/XLA will trace the given function with the given inputs and then generate
+ graphs to represent the pytorch operations that happen within this function. This
+ graph will be compiled by XLA and executed on the accelerator (decided by the
+ tensor's device). Eager mode will be disabled for the compiled region of the function.
+
+ Args:
+ f (Callable): Module/function to optimize. If not passed, this function will
+ act as a context manager.
+ full_graph (Optional[bool]): Whether this compile should generate a single graph. If set to True
+ and multiple graphs would be generated, torch_xla will throw an error with debug info
+ and exit.
+ name (Optional[str]): Name of the compiled program. The name of the function `f` will be used
+ if not specified. This name is used in `PT_XLA_DEBUG` messages as well as in the HLO/IR dump
+ file name.
+ num_different_graphs_allowed (Optional[int]): the maximum number of different traced graphs of the given
+ model/function that we are allowed to have. An error will be raised if this limit
+ is exceeded.
+
+ Example::
+
+ # usage 1
+ @torch_xla.compile()
+ def foo(x):
+ return torch.sin(x) + torch.cos(x)
+
+ def foo2(x):
+ return torch.sin(x) + torch.cos(x)
+ # usage 2
+ compiled_foo2 = torch_xla.compile(foo2)
+
+ # usage 3
+ with torch_xla.compile():
+ res = foo2(x)
+ """
+ if name is None and f is not None:
+ if hasattr(f, '__name__'):
+ name = f.__name__
+ elif hasattr(f, '__str__'):
+ name = f.__str__()
+
+ if f is not None:
+ current_id = f"{name}_{id(f)}"
+ else:
+ current_id = str(uuid.uuid4())
+
+ # Check whether the function/module that corresponds to current_id is still alive. If it's not,
+ # we can remove it from the session's map on the C++ side, so we can start a fresh session.
+ #
+ # This solves the issue where there are 2 different local-scoped functions with the same name.
+ # Since they are local-scoped, they might end up with the same id. And, since they have the same
+ # name, their current_id will be the same, even though they are different functions.
+ #
+ # This issue was observed when running test_dynamic_shape_detector.py.
+ if current_id not in _compiled_id_to_functions_ref:
+ torch_xla._XLAC._dynamic_shape_detector_remove_session(current_id)
+
+ if f is not None:
+ _compiled_id_to_functions_ref[current_id] = f
+
+ def _clear_pending_ops_before_compile():
+ sync()
+
+ @contextlib.contextmanager
+ def _compile():
+ saved_eager_mode_status = torch_xla._XLAC._get_use_eager_mode()
+ saved_allow_execution = torch_xla._XLAC._get_allow_execution()
+ saved_current_graph_name = torch_xla._XLAC._get_current_graph_name()
+ torch_xla._XLAC._set_use_eager_mode(False)
+ if name is not None:
+ torch_xla._XLAC._set_current_graph_name(name + '_clear_pending')
+ # Clear pending operations
+ _clear_pending_ops_before_compile()
+
+ if name is not None:
+ torch_xla._XLAC._set_current_graph_name(name)
+
+ # if full_graph is set to True, execution cannot happen before the sync below
+ torch_xla._XLAC._set_allow_execution(not full_graph)
+
+ if num_different_graphs_allowed is not None:
+ torch_xla._XLAC._dynamic_shape_detector_set_max_num_different_graphs_allowed(
+ num_different_graphs_allowed)
+ torch_xla._XLAC._dynamic_shape_detector_start_session(current_id)
+
+ try:
+ yield
+ finally:
+ torch_xla._XLAC._set_allow_execution(saved_allow_execution)
+ if num_different_graphs_allowed is not None:
+ torch_xla._XLAC._dynamic_shape_detector_end_session()
+ # Collect the traced graph after running the target function and
+ # execute the graph.
+ sync()
+ torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
+ torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
+
+ return _compile() if f is None else _compile()(f)
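+# Illustrative usage (a sketch, not part of the module): in addition to the
+# docstring examples above, `full_graph=True` makes tracing fail fast if more
+# than one graph would be generated.
+#
+#   @torch_xla.compile(full_graph=True)
+#   def step_fn(x):
+#       return torch.sin(x) + torch.cos(x)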
+
+
+[docs]def manual_seed(seed, device=None):
+ """Set the seed for generating random numbers for the current XLA device.
+
+ Args:
+ seed (integer): The state to be set.
+ device (torch.device, optional): The device where the RNG state needs to be set.
+ If missing, the default device seed will be set.
+ """
+ xm.set_rng_state(seed, device)
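+# Illustrative usage (a sketch, not part of the module): seed the default XLA
+# device before generating random tensors, for reproducibility.
+#
+#   torch_xla.manual_seed(42)
+#   t = torch.randn(2, 2, device=torch_xla.device())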
+
+
+# TODO(wcromar): Update args to type ParamSpec.
+def launch(
+ fn: Callable,
+ args: Tuple = (),
+ start_method: str = 'spawn',
+ debug_single_process: bool = False,
+):
+ """ Entry to launch multiprocess.
+
+ Raises:
+ NotImplementedError: SPMD is not supported yet.
+ """
+ if xr.is_spmd():
+ # TODO(piz): SPMD is specified differently from mp. Skip for now.
+ raise NotImplementedError(
+ 'launch function does not support SPMD at this time')
+
+ nprocs = 1 if debug_single_process else None
+
+ if dist.is_torchelastic_launched():
+ fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
+ else:
+ xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
+
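+# Illustrative usage (a sketch, not part of the module; `_mp_fn` is a placeholder
+# training entry point): spawn one process per local device, or a single process
+# when debug_single_process=True.
+#
+#   def _mp_fn(index):
+#       device = torch_xla.device()
+#       ...
+#
+#   if __name__ == '__main__':
+#       torch_xla.launch(_mp_fn, args=())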