diff --git a/release/2.2/_modules/index.html b/release/2.2/_modules/index.html
new file mode 100644
index 00000000000..0cf71fc10d8
--- /dev/null
+++ b/release/2.2/_modules/index.html
@@ -0,0 +1,658 @@
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+
+
+class AllReduce(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, reduce_type, scale, groups):
+ ctx.reduce_type = reduce_type
+ ctx.scale = scale
+ output = xm.all_reduce(reduce_type, input, scale=scale, groups=groups)
+ ctx.save_for_backward(input, output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, output = ctx.saved_tensors
+ grad = grad_output * ctx.scale if ctx.scale != 1.0 else grad_output
+ if ctx.reduce_type == xm.REDUCE_SUM:
+ return grad, None, None, None
+ if ctx.reduce_type == xm.REDUCE_MUL:
+ # MUL is not supported by TPU
+ grad_scaler = torch.where(input != 0, output / input,
+ torch.zeros_like(input))
+ return grad * grad_scaler, None, None, None
+ if ctx.reduce_type == xm.REDUCE_MIN or ctx.reduce_type == xm.REDUCE_MAX:
+ return torch.where(input == output, grad,
+ torch.zeros_like(grad)), None, None, None
+ raise RuntimeError('Unsupported reduce type: {}'.format(ctx.reduce_type))
+
+
+[docs]def all_reduce(reduce_type, value, scale=1.0, groups=None):
+ """Performs an inplace reduce operation on the input tensor.
+
+ This is the same as `xm.all_reduce()` but supports autograd differentiation.
+
+ Args:
+ reduce_type (string): One of ``REDUCE_SUM``, ``REDUCE_MUL``, ``REDUCE_AND``,
+ ``REDUCE_OR``, ``REDUCE_MIN`` and ``REDUCE_MAX``.
+    value (torch.Tensor): The tensor to perform the all reduce op on.
+ scale (float): A default scaling value to be applied after the reduce.
+ Default: 1.0
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+ Returns:
+ The reduced value across the selected replicas.
+ """
+ return AllReduce.apply(value, reduce_type, scale, groups)
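+
+
+# Editor's illustrative sketch (not part of the original module): a minimal,
+# hedged example of how the autograd-aware `all_reduce` above could be used to
+# average a differentiable activation across replicas. `activations` is an
+# assumed placeholder XLA tensor that requires grad.
+def _example_all_reduce_mean(activations):
+  scale = 1.0 / xm.xrt_world_size()
+  # Gradients flow back to `activations` through AllReduce.backward.
+  return all_reduce(xm.REDUCE_SUM, activations, scale=scale)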
+
+
+class AllGather(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, dim):
+ ctx.dim = dim
+ ctx.ordinal = xm.get_ordinal()
+ ctx.world_size = xm.xrt_world_size()
+ return xm.all_gather(input, dim=dim)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ slice_size = grad_output.size(ctx.dim) // ctx.world_size
+ return torch.narrow(grad_output.clone(), ctx.dim, ctx.ordinal * slice_size,
+ slice_size), None
+
+
+[docs]def all_gather(value, dim=0):
+ """Performs an all-gather operation along a given dimension.
+
+ This is the same as `xm.all_gather()` but supports autograd differentiation.
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ dim (int): The gather dimension.
+ Default: 0
+ Returns:
+ A tensor which has, in the ``dim`` dimension, all the values from the
+ participating replicas.
+ """
+ return AllGather.apply(value, dim)
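+
+
+# Editor's illustrative sketch (not part of the original module), assuming a
+# per-replica tensor of shape [B, D]: gathers along the batch dimension while
+# staying differentiable, so the result can feed into a loss.
+def _example_all_gather_batch(local_embeddings):
+  # Result has shape [B * world_size, D]; the backward pass narrows the
+  # incoming gradient back to this replica's slice.
+  return all_gather(local_embeddings, dim=0)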
+
+
+[docs]def nms(boxes, scores, score_threshold, iou_threshold, output_size):
+ """Performs a Non Maximal Suppression operation.
+
+ Args:
+ boxes (torch.Tensor): A `torch.Tensor` of shape `[N, 4]` listing the boxes
+ coordinates in `(y0, x0, y1, x1)` form.
+ scores (torch.Tensor): A `torch.Tensor` of shape `[N]` listing the scores
+ of each box.
+ score_threshold (torch.Tensor): The minimum score for a box to qualify as
+ valid.
+ iou_threshold (torch.Tensor): The minimum IOU (Intersection Over Union)
+ score to trigger overlap logic.
+ output_size (int): The maximum number of returned indices (must be lower or
+ equal to N).
+
+ Returns:
+ A tuple of `torch.Tensor` with the first element being the selected box
+ indices, and the second element being the number of valid boxes.
+ """
+ return torch_xla._XLAC._xla_nms(boxes, scores, score_threshold, iou_threshold,
+ output_size)
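+
+
+# Editor's illustrative sketch (not part of the original module): calling the
+# NMS wrapper above with toy boxes/scores. The thresholds are passed as tensors
+# on the same XLA device, matching the argument types documented above.
+def _example_nms(device):
+  boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]], device=device)
+  scores = torch.tensor([0.9, 0.8], device=device)
+  selected, num_valid = nms(
+      boxes, scores,
+      score_threshold=torch.tensor(0.1, device=device),
+      iou_threshold=torch.tensor(0.5, device=device),
+      output_size=2)
+  # `selected` holds up to `output_size` indices; only the first `num_valid`
+  # entries are meaningful.
+  return selected, num_valid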
+
+
+def distributed_mm(w, x, split=1):
+ """Performs a matrix multiplication with sharded weight.
+
+ Args:
+ w (torch.Tensor): The sharded weight, RHS of the matrix multiplication
+ operation. The weight shape is `N x Ko` where `Ko` is the shard
+ dimension size. Each ordinal will have its own copy of the weight.
+ x (torch.Tensor): The input tensor, LHS of the matrix multiplication
+ operation. The input shape is `WG x M` where `WG = Ko * WORLD_SIZE`.
+ split (int): The number of splits for the `M` dimension of `x`. Since
+ there is an `all_gather()` on such dimension, if `M` is big, a split
+ might be required in order to fit device memory.
+ Default: 1
+ Returns:
+ The result of the distributed matrix multiplication operation.
+ """
+ ordinal = xm.get_ordinal()
+ # w = N x Ko
+ # WG = Ko * WORLD_SIZE
+ # x = WG x M
+ assert x.size(0) // xm.xrt_world_size() == w.size(1)
+ splits = []
+ if split != 1:
+ size = x.size(1)
+ assert size % split == 0
+ split_size = size // split
+ splits = torch.split(x, split_size, dim=1)
+ else:
+ splits.append(x)
+ results = []
+ for xs in splits:
+ # xg = WG x (M * WORLD_SIZE)
+ xg = all_gather(xs, dim=1)
+ # xgn = Ko x (M * WORLD_SIZE)
+ xgn = torch.narrow(xg, 0, ordinal * w.size(1), w.size(1))
+ # wxg = N x (M * WORLD_SIZE)
+ wxg = w @ xgn
+ # rwxg = N x (M * WORLD_SIZE)
+ rwxg = all_reduce(xm.REDUCE_SUM, wxg)
+ # wx = N x M
+ wx = torch.narrow(rwxg, 1, ordinal * xs.size(1), xs.size(1))
+ results.append(wx)
+ return torch.cat(results, dim=1) if len(results) > 1 else results[0]
+
+
+class SyncBatchNorm(torch.nn.Module):
+
+ def __init__(
+ self,
+ num_features: int,
+ eps: float = 1e-5,
+ momentum: float = 0.1,
+ ):
+ super().__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.weight = torch.nn.Parameter(torch.ones(num_features))
+ self.bias = torch.nn.Parameter(torch.zeros(num_features))
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
+ assert 2 <= batch.ndim <= 5 and batch.shape[1] == self.num_features
+ reduce_dims = list(range(batch.ndim))
+ reduce_dims.pop(1) # channel dim
+
+ if self.training:
+ local_mean = torch.mean(batch, dim=reduce_dims)
+ local_sqr_mean = torch.mean(batch * batch, dim=reduce_dims)
+
+ scale = 1.0 / xm.xrt_world_size()
+ mean = AllReduceSumLayer.apply(local_mean) * scale
+ sqr_mean = AllReduceSumLayer.apply(local_sqr_mean) * scale
+ var = sqr_mean - mean.pow(2)
+
+ self.running_mean = (
+ 1 - self.momentum) * self.running_mean + self.momentum * mean
+ self.running_var = (
+ 1 - self.momentum) * self.running_var + self.momentum * var
+ else:
+ mean = self.running_mean
+ var = self.running_var
+
+ res = torch.empty_like(batch)
+ for c in range(self.num_features):
+ if batch.ndim == 2:
+ res = ((batch - mean) /
+ torch.sqrt(var + self.eps)) * self.weight + self.bias
+ else:
+ res[:, c, ...] = (
+ (batch[:, c, ...] - mean[c]) /
+ torch.sqrt(var[c] + self.eps)) * self.weight[c] + self.bias[c]
+
+ return res
+
+ def extra_repr(self) -> str:
+ return f'{self.num_features}, eps={self.eps}'
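+
+
+# Editor's illustrative sketch (not part of the original module): recursively
+# replacing `torch.nn.BatchNorm*` layers in an existing model with the
+# SyncBatchNorm above, so batch statistics are averaged across replicas.
+# `model` is an assumed placeholder `torch.nn.Module`.
+def _example_convert_to_sync_batchnorm(model):
+  for name, child in model.named_children():
+    if isinstance(child, torch.nn.modules.batchnorm._BatchNorm):
+      setattr(model, name,
+              SyncBatchNorm(child.num_features, child.eps, child.momentum))
+    else:
+      _example_convert_to_sync_batchnorm(child)
+  return model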
+
+
+class AllReduceSumLayer(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x):
+ return xm.all_reduce(xm.REDUCE_SUM, x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return xm.all_reduce(xm.REDUCE_SUM, grad_output)
+
+import io
+import itertools
+import logging
+import sys
+import re
+import threading
+import time
+import warnings
+from typing import List, Optional
+import torch
+import torch.distributed._functional_collectives
+import torch.nn.functional as F
+import torch_xla
+from torch_xla import runtime
+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.debug.metrics_saver as ms
+import torch_xla.utils.utils as xu
+import torch_xla.utils.closures as xc
+
+_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
+
+REDUCE_SUM = 'sum'
+REDUCE_MUL = 'mul'
+REDUCE_AND = 'and'
+REDUCE_OR = 'or'
+REDUCE_MIN = 'min'
+REDUCE_MAX = 'max'
+
+_DEVICE_CONTEXTS = dict()
+_DEVICE_CONTEXTS_LOCK = threading.Lock()
+
+# Note [Dynamo WORLD_SIZE and ORDINAL]
+# Below is a workaround to cache the ordinal and world_size such that
+# Dynamo won't do graph breaks when xm.xrt_world_size() and xm.get_ordinal() are called.
+_WORLD_SIZE = None
+_ORDINAL = None
+
+
+def _init_world_size_ordinal():
+ global _WORLD_SIZE, _ORDINAL
+
+ # Dynamo doesn't support XRT or multithreaded runtime. See Note [V3-8 Threading]
+ if not runtime.using_pjrt() or runtime.addressable_device_count() > 1:
+ return
+
+ if _WORLD_SIZE is None:
+ _WORLD_SIZE = xrt_world_size()
+ _ORDINAL = get_ordinal()
+
+
+class DeviceContext(object):
+
+ def __init__(self, device):
+ self.device = device
+
+
+def _get_device_context(device=None):
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ else:
+ device = str(device)
+ with _DEVICE_CONTEXTS_LOCK:
+ devctx = _DEVICE_CONTEXTS.get(device, None)
+ if devctx is None:
+ devctx = DeviceContext(device)
+ _DEVICE_CONTEXTS[device] = devctx
+ return devctx
+
+
+def is_xla_tensor(tensor):
+ return tensor.device.type == 'xla'
+
+
+def parse_xla_device(device):
+ m = re.match(r'(CPU|TPU|GPU|ROCM|CUDA|XPU|NEURON):(\d+)$', device)
+ if m:
+ return (m.group(1), int(m.group(2)))
+
+
+[docs]def get_xla_supported_devices(devkind=None, max_devices=None):
+ """Returns a list of supported devices of a given kind.
+
+ Args:
+ devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU`,
+ `NEURON` or `CPU` (the 'GPU' XLA device is currently not implemented).
+ max_devices (int, optional): The maximum number of devices to be returned of
+ that kind.
+
+ Returns:
+ The list of device strings.
+ """
+ # TODO(xiowei replace gpu with cuda): Remove the below if statement after r2.2 release.
+ if devkind and devkind.casefold() == 'gpu':
+ warnings.warn(
+ "GPU as a device name is being deprecate. Please replace it with CUDA such as get_xla_supported_devices(devkind='CUDA'). Similarly, please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA."
+ )
+ devkind = 'CUDA'
+
+ xla_devices = _DEVICES.value
+ devkind = [devkind] if devkind else [
+ 'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM'
+ ]
+ for kind in devkind:
+ kind_devices = []
+ for i, device in enumerate(xla_devices):
+ if re.match(kind + r':\d+$', device):
+ kind_devices.append('xla:{}'.format(i))
+ if kind_devices:
+ return kind_devices[:max_devices] if max_devices else kind_devices
+
+
+[docs]def xrt_world_size(defval=1):
+ """Retrieves the number of devices which is taking part of the replication.
+
+ Args:
+ defval (int, optional): The default value to be returned in case there is no
+ replication information available.
+ Default: 1
+
+ Returns:
+    The number of devices which are taking part in the replication.
+ """
+ global _WORLD_SIZE
+ if _WORLD_SIZE is not None:
+ return _WORLD_SIZE
+
+ return runtime.world_size()
+
+
+[docs]def get_ordinal(defval=0):
+ """Retrieves the replication ordinal of the current thread.
+
+ The ordinals range from 0 to `xrt_world_size()` minus 1.
+
+ Args:
+ defval (int, optional): The default value to be returned in case there is no
+ replication information available. Ignored for runtime.
+ Default: 0
+
+ Returns:
+ The replication ordinal of the current thread.
+ """
+ global _ORDINAL
+ if _ORDINAL is not None:
+ return _ORDINAL
+
+ return runtime.global_ordinal()
+
+
+[docs]def get_local_ordinal(defval=0):
+ """Retrieves the replication local ordinal of the current thread.
+
+ The local ordinals range from 0 to the number of local devices minus 1.
+
+ Args:
+ defval (int, optional): The default value to be returned in case there is no
+ replication information available. Ignored for runtime.
+ Default: 0
+
+ Returns:
+ The replication local ordinal of the current thread.
+ """
+ return runtime.local_ordinal()
+
+
+[docs]def is_master_ordinal(local=True):
+ """Checks whether the current process is the master ordinal (0).
+
+ Args:
+ local (bool): Whether the local or global master ordinal should be checked.
+ In case of multi-host replication, there is only one global master ordinal
+ (host 0, device 0), while there are NUM_HOSTS local master ordinals.
+ Default: True
+
+ Returns:
+ A boolean indicating whether the current process is the master ordinal.
+ """
+ ordinal = get_local_ordinal() if local else get_ordinal()
+ return ordinal == 0
+
+
+def master_print(*args, fd=sys.stdout, local=False, flush=False):
+ if is_master_ordinal(local=local):
+ print(*args, file=fd, flush=flush)
+
+
+[docs]def xla_device(n=None, devkind=None):
+ """Returns a given instance of an XLA device.
+
+ Args:
+ n (int, optional): The specific instance (ordinal) to be returned. If
+ specified, the specific XLA device instance will be returned. Otherwise
+ the first device of `devkind` will be returned.
+ devkind (string..., optional): If specified, one of `TPU`, `CUDA`, `XPU`
+ `NEURON`, `ROCM` or `CPU`.
+
+ Returns:
+ A `torch.device` with the requested instance.
+ """
+ # When SPMD is enabled, we always return `xla:0` to the user, and
+ # under the hood we use virtual device logic for every xla tensor
+ if xu.check_env_flag('XLA_USE_SPMD'):
+ device = 'xla:0'
+ torch_xla._XLAC._xla_set_default_device(device)
+ return torch.device(device)
+
+ return runtime.xla_device(n, devkind)
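+
+
+# Editor's illustrative sketch (not part of the original module): the typical
+# pattern of acquiring an XLA device with `xla_device()` above and moving a
+# model and a batch onto it. `model` and `batch` are assumed placeholders.
+def _example_to_xla(model, batch):
+  device = xla_device()
+  model = model.to(device)
+  return model(batch.to(device))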
+
+
+def _xla_real_device(device):
+ device_str = str(device)
+ m = re.match(r'xla:(\d+)$', device_str)
+ if not m:
+ raise RuntimeError('Invalid device format: {}'.format(device_str))
+ return _DEVICES.value[int(m.group(1))]
+
+
+def xla_real_devices(devices):
+ return [_xla_real_device(device) for device in devices]
+
+
+[docs]def xla_device_hw(device):
+ """Returns the hardware type of the given device.
+
+ Args:
+ device (string or torch.device): The xla device that will be mapped to the
+ real device.
+
+ Returns:
+ A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `NEURON`, `GPU`, `CUDA`, `ROCM`)
+ of the given device.
+ """
+ real_device = _xla_real_device(device)
+ return real_device.split(':')[0]
+
+
+def xla_replication_devices(local_devices):
+ real_devices = xla_real_devices(local_devices)
+ device_types = set()
+ for device in real_devices:
+ xdev = parse_xla_device(device)
+ device_types.add(xdev[0])
+ if len(device_types) != 1:
+    # No replication if the device set spans multiple device types.
+ raise RuntimeError(
+ 'Cannot replicate across different device types: devices={}/{}'.format(
+ local_devices, real_devices))
+ device_type = device_types.pop()
+ kind_devices = get_xla_supported_devices(devkind=device_type)
+ if len(kind_devices) != len(local_devices):
+ # Replication can only happen among all devices of one kind.
+ raise RuntimeError(
+ 'Cannot replicate if number of devices ({}) is different from {}'.
+ format(len(local_devices), len(kind_devices)))
+ replication_devices = []
+ for device in torch_xla._XLAC._xla_get_all_devices():
+ xdev = parse_xla_device(device)
+ if not xdev:
+ raise RuntimeError('Invalid device format: {}'.format(device))
+ if xdev[0] == device_type:
+ replication_devices.append(device)
+ sorted_by_ordinal = sorted(
+ replication_devices, key=lambda device: parse_xla_device(device)[1])
+ return sorted_by_ordinal
+
+
+def unlazy(tensors):
+ """Blocks the program until `tensors` are materialized.
+
+ This API is for benchmarking, don't use it in real models.
+
+ Args:
+ tensors (List[torch.Tensor]): List of `torch.Tensor`s to materialize. For
+ each Tensor `t` in the list, `t.device` must be an `xla` device.
+ """
+ torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True)
+
+
+def set_replication(device, devices):
+ device = str(device)
+ devctx = _get_device_context(device=device)
+ devices = [str(x) for x in devices]
+ if devices:
+ replication_devices = xla_replication_devices(devices)
+ torch_xla._XLAC._xla_set_replication_devices(replication_devices)
+ devctx.device_index = devices.index(device)
+ else:
+ torch_xla._XLAC._xla_set_replication_devices([])
+ devctx.device_index = 0
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
+ torch_xla._XLAC._xla_set_default_device(device)
+
+
+class RateTracker(object):
+
+ def __init__(self, smooth_factor=None):
+ self._smooth_factor = xu.getenv_as(
+ 'RATE_TRACKER_SMOOTHING', float,
+ 0.4) if smooth_factor is None else smooth_factor
+ self._start_time = time.time()
+ self._partial_time = self._start_time
+ self._partial_count = 0.0
+ self._partial_rate = None
+ self._count = 0.0
+
+ def _update(self, now, rate):
+ self._partial_count += self._count
+ self._count = 0.0
+ self._partial_time = now
+ self._partial_rate = rate
+
+ def add(self, count):
+ self._count += count
+
+ def _smooth(self, current_rate):
+ if self._partial_rate is None:
+ smoothed_rate = current_rate
+ else:
+ smoothed_rate = ((1 - self._smooth_factor) * current_rate +
+ self._smooth_factor * self._partial_rate)
+ return smoothed_rate
+
+ def rate(self):
+ now = time.time()
+ delta = now - self._partial_time
+ report_rate = 0.0
+ if delta > 0:
+ report_rate = self._smooth(self._count / delta)
+ self._update(now, report_rate)
+ return report_rate
+
+ def global_rate(self):
+ delta = time.time() - self._start_time
+ count = self._partial_count + self._count
+ return count / delta if delta > 0 else 0.0
+
+
+class ToXlaTensorArena(object):
+
+ def __init__(self, convert_fn, select_fn):
+ self._convert_fn = convert_fn
+ self._select_fn = select_fn
+ self._tensors = []
+
+ def _add(self, tensor):
+ self._tensors.append(tensor)
+
+ def _convert(self):
+ self._index = 0
+ if self._tensors:
+ self._converted_tensors = self._convert_fn(self._tensors)
+ else:
+ self._converted_tensors = []
+
+ def _get_converted_tensor(self):
+ assert self._index < len(self._converted_tensors)
+ new_tensor = self._converted_tensors[self._index]
+ self._index += 1
+ return new_tensor
+
+ def _collect_tensors(self, inputs):
+
+ def collect_fn(value):
+ self._add(value)
+
+ xu.for_each_instance(inputs, lambda x: self._select_fn(x), collect_fn)
+
+ def _replace_tensors(self, inputs):
+
+ def convert_fn(value):
+ return self._get_converted_tensor()
+
+ return xu.for_each_instance_rewrite(inputs, lambda x: self._select_fn(x),
+ convert_fn)
+
+ def transform(self, inputs):
+ self._tensors = []
+ self._collect_tensors(inputs)
+ self._convert()
+ return self._replace_tensors(inputs)
+
+
+def check_view_sharing(obj):
+ tensors = set()
+ aliases = dict()
+
+ def tensor_info(t):
+ return '{}{}'.format(t.dtype, list(t.size()))
+
+ def tensor_id(t):
+ if is_xla_tensor(t):
+ return torch_xla._XLAC._xla_get_tensor_id(t), 'xla'
+ return id(t), 'torch'
+
+ def alias_id(t):
+ if is_xla_tensor(t):
+ aid = torch_xla._XLAC._xla_get_tensor_view_alias_id(t)
+ return None if aid == 0 else aid, 'xla'
+ return t.storage().data_ptr(), 'torch'
+
+ def check_object(obj):
+ tid = tensor_id(obj)
+ if tid not in tensors:
+ tensors.add(tid)
+ aid = alias_id(obj)
+ if aid[0] is not None:
+ if aid in aliases:
+ oobj = aliases[aid]
+ raise RuntimeError(
+ 'Tensor ID {} ({}) is sharing a view with tensor ID {} ({})'.
+ format(tid, tensor_info(obj), tensor_id(oobj), tensor_info(oobj)))
+ aliases[aid] = obj
+
+ xu.for_each_instance(obj, lambda x: type(x) == torch.Tensor, check_object)
+
+
+def _fetch_gradients(optimizer):
+ gradients = []
+ for param_group in optimizer.__getstate__()['param_groups']:
+ for group, params in param_group.items():
+ if group == 'params':
+ for p in params:
+ if isinstance(p, torch.Tensor) and p.grad is not None:
+ gradients.append(p.grad.data)
+ return gradients
+
+
+def _get_all_reduce_token():
+ devctx = _get_device_context()
+ token = torch_xla._XLAC._get_all_reduce_token(devctx.device)
+ return token, devctx
+
+
+[docs]def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
+ """Performs an inplace reduce operation on the input tensor(s).
+
+ Args:
+ reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
+ ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
+ ``xm.REDUCE_MAX``.
+ inputs: Either a single `torch.Tensor` or a list of `torch.Tensor` to
+ perform the all reduce op to.
+ scale (float): A default scaling value to be applied after the reduce.
+ Default: 1.0
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ If a single `torch.Tensor` is passed, the return value is a `torch.Tensor`
+ holding the reduced value (across the replicas). If a list/tuple is passed,
+ this function performs an inplace all-reduce op on the input tensors, and
+ returns the list/tuple itself.
+ """
+ groups = groups or []
+
+ # No-op if there is only one device
+ if xrt_world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE', bool,
+ False):
+ if isinstance(inputs, torch.Tensor):
+ return inputs.clone()
+ else:
+ return inputs
+
+ if isinstance(inputs, torch.Tensor):
+ result = None
+ if scale == 1.0 and groups == [] and pin_layout:
+ # TODO(alanwaketan): Support groups.
+ # Only c10d_functional version cc ops are traceable by Dynamo.
+ result = torch.ops.c10d_functional.all_reduce(inputs, reduce_type, "", [],
+ 0)
+ else:
+ result = torch_xla._XLAC._xla_all_reduce(reduce_type, inputs, scale,
+ groups, pin_layout)
+ results = [result]
+ else:
+ torch_xla._XLAC._xla_all_reduce_inplace(reduce_type, inputs, scale, groups,
+ pin_layout)
+ results = inputs
+
+ return results[0] if isinstance(inputs, torch.Tensor) else results
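+
+
+# Editor's illustrative sketch (not part of the original module): averaging a
+# scalar metric across all replicas with the `all_reduce` above, e.g. to log a
+# globally consistent loss value. `loss` is an assumed placeholder XLA tensor.
+def _example_mesh_average(loss):
+  return all_reduce(REDUCE_SUM, loss.detach(), scale=1.0 / xrt_world_size())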
+
+
+def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True):
+ """Performs an all-gather operation using all-reduce along a given dimension.
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ dim (int): The gather dimension.
+ Default: 0
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+ output (torch.Tensor): Optional output tensor.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ A tensor which has, in the ``dim`` dimension, all the values from the
+ participating replicas.
+ """
+ if dim < 0:
+ dim = value.dim() + dim
+ size = value.size(dim)
+ padding = [0] * (2 * value.dim())
+ ordinal = get_ordinal()
+ if groups is None:
+ left, right = ordinal, xrt_world_size() - 1 - ordinal
+ else:
+ ordinals = dict()
+ for g in groups:
+ for i, x in enumerate(g):
+ ordinals[x] = (i, len(g) - 1 - i)
+ left, right = ordinals[ordinal]
+ idx = value.dim() - 1 - dim
+ padding[2 * idx] = left * size
+ padding[2 * idx + 1] = right * size
+ return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
+
+
+[docs]def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
+ """Performs an all-gather operation along a given dimension.
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ dim (int): The gather dimension.
+ Default: 0
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_gather()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+ output (torch.Tensor): Optional output tensor.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ A tensor which has, in the ``dim`` dimension, all the values from the
+ participating replicas.
+ """
+ # _all_gather_using_all_reduce does not support list of tensors as input
+ if pin_layout and output == None and isinstance(value, torch.Tensor):
+ # There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON,
+ # use all_reduce based all_gather for this purpose.
+ return _all_gather_using_all_reduce(
+ value, dim=dim, groups=groups, pin_layout=True)
+
+ if dim < 0:
+ dim = value.dim() + dim
+ if groups:
+ shard_count = len(groups[0])
+ assert all(len(group) == shard_count for group in groups), \
+ "Replica groups must have the same number of replicas/shards."
+ else:
+ # All replicas belong to a single group
+ shard_count = xrt_world_size()
+
+ token, devctx = _get_all_reduce_token()
+
+ if isinstance(value, torch.Tensor):
+ if output != None:
+ # Call the out of place version of the all_gather
+ new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
+ shard_count, groups or [],
+ pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
+ [], pin_layout)
+ return result
+
+ # Now the input should be a list of Tensors.
+ elif isinstance(value, list) and all(
+ isinstance(v, torch.Tensor) for v in value):
+ if pin_layout:
+ raise RuntimeError(
+ "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
+ )
+ if output != None:
+ if not isinstance(output, list) or any(
+ not isinstance(v, torch.Tensor) for v in output):
+ raise TypeError(
+ f"`output` needs to be a list of Tensors, but given {type(output)}."
+ )
+ if len(output) != len(value):
+ raise ValueError("`output` length doesn't match `input` length: "
+ f"{len(output)} vs {len(input)}.")
+ # Call the out of place version of the reduce_scatter
+ new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
+ output, value, token, dim, shard_count, groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
+ shard_count, groups or
+ [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+ return result[:-1]
+ else:
+ raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
+ f"given {type(value)}.")
+
+
+[docs]def all_to_all(value,
+ split_dimension,
+ concat_dimension,
+ split_count,
+ groups=None,
+ pin_layout=True):
+ """Performs an XLA `AllToAll()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#alltoall
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ split_dimension (int): The dimension upon which the split should happen.
+ concat_dimension (int): The dimension upon which the concat should happen.
+ split_count (int): The split count.
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ The result `torch.Tensor` of the `all_to_all()` operation.
+ """
+ token, devctx = _get_all_reduce_token()
+ result = torch_xla._XLAC._xla_all_to_all(value, token, split_dimension,
+ concat_dimension, split_count,
+ groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+ return result[0]
+
+
+def collective_permute(value, pairs):
+ """Performs a XLA `CollectivePermute()` operation on the input tensor.
+
+ WARNING: This function is not very reliable, may produce wrong results under
+ certain inputs. Use it at your own risk.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ pairs (list): A list of (source_replica_id, target_replica_id) pairs,
+ representing the sender and receiver for the `collective_permute()`
+ operation. Example: `[[0, 1], [1, 2], [2, 0]]` defines three pairs. The
+ tensor will be sent from replica 0 to replica 1, replica 1 to replica 2,
+ and replica 2 to replica 0.
+
+ Returns:
+ The result `torch.Tensor` of the `collective_permute()` operation.
+ """
+ token, devctx = _get_all_reduce_token()
+ result = torch_xla._XLAC._xla_collective_permute(value, token, pairs)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+ return result[0]
+
+
+def collective_broadcast(tensors: List[torch.Tensor],
+ root_ordinal: int = 0,
+ groups: Optional[List[int]] = None,
+ pin_layout: bool = True) -> None:
+ """Broadcast values of `tensors` from root replica to other replicas in-place.
+
+ Args:
+ tensors (list): List of `torch.Tensor`s to broadcast.
+ root_ordinal (int): Ordinal of replica with values to broadcast.
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
+ """
+ with torch.no_grad():
+ # We must produce the exact same graph in each replica to prevent hanging,
+ # so each replica must have the same multiply op with the same parameters.
+ for tensor in tensors:
+ scale = torch.tensor(
+ 1 if get_ordinal() == root_ordinal else 0, dtype=tensor.dtype)
+ # Transfer scale tensor as device data instead of constant 1 or 0.
+ xscale = send_cpu_data_to_device(scale, tensor.device)
+ tensor.mul_(xscale[0])
+
+ all_reduce(REDUCE_SUM, tensors, groups=groups, pin_layout=pin_layout)
+
+
+def send(value, channel_id):
+ """Performs a XLA `Send()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#send
+
+ Args:
+ value (torch.Tensor): The input tensor.
+ channel_id (int64): opaque id identifying the destination of the send op.
+ """
+ token, devctx = _get_all_reduce_token()
+ # The input will be returned as result.
+ input_as_result, new_token = torch_xla._XLAC._xla_send(
+ value, token, channel_id)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return input_as_result
+
+
+def recv(output, channel_id):
+ """Performs a XLA `Send()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#recv
+
+ Args:
+ output (torch.Tensor): The output tensor.
+ channel_id (int64): opaque id identifying the source of the recv op.
+ """
+ token, devctx = _get_all_reduce_token()
+ result, new_token = torch_xla._XLAC._xla_recv(output, token, channel_id)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return result
+
+
+def reduce_scatter(reduce_type,
+ input,
+ scale,
+ scatter_dim,
+ shard_count,
+ groups=None,
+ output=None,
+ pin_layout=True):
+ """Performs a XLA `ReduceScatter()` operation on the input tensor.
+
+ See: https://www.tensorflow.org/xla/operation_semantics#reducescatter
+
+ Args:
+ reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
+ ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
+ ``xm.REDUCE_MAX``.
+ input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
+ it will also be the output.
+ scale (float): A default scaling value to be applied after the reduce.
+ scatter_dim (int): Dimension number to which apply scatter operation.
+    shard_count (int): The number of shards to split the scatter_dim into.
+ groups (list): A list of list, representing the replica groups for
+ the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+ output: Optional output tensor if `input` is a torch.Tensor, or a list of
+ torch.Tensor if `input` is a list of torch.Tensor.
+    pin_layout (bool, optional): whether to pin the layout for this communication op.
+      Layout pinning can prevent potential data corruption when each process that
+      participates in the communication has a slightly different program, but it
+      might cause some XLA compilations to fail. Unpin the layout when you see an
+      error message like "HloModule has a mix of layout constrained".
+
+ Returns:
+ A `torch.Tensor` with all the values reduced across replicas. Each process
+ gets a shard split along the `scatter_dim`. All other dimensions are
+ the same as the input.
+ """
+ token, devctx = _get_all_reduce_token()
+
+ if isinstance(input, torch.Tensor):
+ if output != None:
+ # Call the out of place version of the reduce_scatter
+ new_token = torch_xla._XLAC._xla_reduce_scatter_out(
+ reduce_type, output, input, token, scale, scatter_dim, shard_count,
+ groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
+ scale, scatter_dim,
+ shard_count, groups or [],
+ pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+ return result[0]
+
+ # Now the input should be a list of Tensors.
+ elif isinstance(input, list) and all(
+ isinstance(v, torch.Tensor) for v in input):
+ if output != None:
+ if not isinstance(output, list) or any(
+ not isinstance(v, torch.Tensor) for v in output):
+ raise TypeError(
+ f"`output` needs to be a list of Tensors, but given {type(output)}."
+ )
+ if len(output) != len(input):
+ raise ValueError("`output` length doesn't match `input` length: "
+ f"{len(output)} vs {len(input)}.")
+ # Call the out of place version of the reduce_scatter
+ new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
+ reduce_type, output, input, token, scale, scatter_dim, shard_count,
+ groups or [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+ return output
+
+ result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
+ reduce_type, input, token, scale, scatter_dim, shard_count, groups or
+ [], pin_layout)
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+ return result[:-1]
+ else:
+ raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
+ f"given {type(input)}.")
+
+
+[docs]def add_step_closure(closure, args=(), run_async=False):
+ """Adds a closure to the list of the ones to be run at the end of the step.
+
+  Many times during model training there is the need to print/report (print to
+  console, post to tensorboard, etc...) information which requires the content of
+  intermediary tensors to be inspected.
+  Inspecting the content of different tensors at different points of the model
+  code requires many executions and typically causes performance issues.
+  Adding a step closure will ensure that it will be run after the barrier, when
+  all the live tensors will already be materialized to device data. The live
+  tensors include the ones captured by the closure arguments.
+  So using `add_step_closure()` will ensure a single execution will be
+  performed, even when multiple closures are queued, requiring multiple tensors
+  to be inspected.
+  Step closures will be run sequentially in the order they have been queued.
+  Note that even though the execution is optimized when using this API, it is
+  advised to throttle the printing/reporting events to once every N steps.
+
+ Args:
+ closure (callable): The function to be called.
+ args (tuple): The arguments to be passed to the closure.
+ run_async: If True, run the closure asynchronously.
+ """
+ devctx = _get_device_context()
+ closures_type = 'async_step_closures' if run_async else 'step_closures'
+ step_closures = getattr(devctx, closures_type, None)
+ if step_closures is None:
+ step_closures = []
+ setattr(devctx, closures_type, step_closures)
+ step_closures.append(lambda a=args: closure(*a))
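+
+
+# Editor's illustrative sketch (not part of the original module): reporting a
+# loss value through a step closure so the tensor is only inspected once it has
+# been materialized at the step barrier. `step` and `loss` are assumed
+# placeholders coming from the caller's training loop.
+def _example_report_loss(step, loss):
+
+  def _closure(step, loss):
+    # Runs after mark_step(); calling .item() here does not force an early
+    # execution of the graph.
+    print('step {}: loss={:.4f}'.format(step, loss.item()))
+
+  add_step_closure(_closure, args=(step, loss))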
+
+
+def _run_step_closures():
+ devctx = _get_device_context()
+ async_step_closures = getattr(devctx, 'async_step_closures', None)
+ if async_step_closures is not None:
+ devctx.async_step_closures = []
+ async_closure_handler = getattr(devctx, 'async_closure_handler', None)
+ if async_closure_handler is None:
+ async_closure_handler = xc.AsyncClosureHandler()
+ devctx.async_closure_handler = async_closure_handler
+ async_closure_handler.run_all(async_step_closures)
+
+ step_closures = getattr(devctx, 'step_closures', None)
+ if step_closures is not None:
+ devctx.step_closures = []
+ for closure in step_closures:
+ closure()
+ return devctx
+
+
+def mark_step(wait=False):
+ if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
+ print(
+ 'torch_xla.core.xla_model::mark_step\n',
+ end='',
+ file=sys.stderr,
+ flush=True)
+ torch_xla._XLAC._xla_step_marker(
+ torch_xla._XLAC._xla_get_default_device(), [],
+ wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
+ # Only emit metrics from the first local device index, to avoid emitting the
+ # same values from different threads.
+ if is_master_ordinal():
+ ms.save_metrics()
+ devctx = _run_step_closures()
+ torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
+
+
+[docs]def get_stablehlo(tensors=None) -> str:
+ """Get StableHLO for the computation graph in string format.
+
+  If `tensors` is not empty, the graph with `tensors` as outputs will be dumped.
+  If `tensors` is empty, the whole computation graph will be dumped.
+  TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
+  dumped as outputs. Needs further investigation.
+
+  For an inference graph, it is recommended to pass the model outputs to `tensors`.
+  For a training graph, it is not straightforward to identify the "outputs". Using empty `tensors` is recommended.
+
+ To enable source line info in StableHLO, please set env var XLA_HLO_DEBUG=1.
+
+ Args:
+ tensors (list[torch.Tensor], optional): Tensors that represent the output/root of the StableHLO graph.
+
+ Returns:
+ StableHLO Module in string format.
+ """
+ if tensors is None:
+ tensors = []
+ return torch_xla._XLAC._get_stablehlo(
+ tensors, torch_xla._XLAC._xla_get_default_device(), [],
+ False).decode('utf-8')
+
+
+[docs]def get_stablehlo_bytecode(tensors=None) -> bytes:
+ """Get StableHLO for the computation graph in bytecode format.
+
+  If `tensors` is not empty, the graph with `tensors` as outputs will be dumped.
+  If `tensors` is empty, the whole computation graph will be dumped.
+  TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
+  dumped as outputs. Needs further investigation.
+
+  For an inference graph, it is recommended to pass the model outputs to `tensors`.
+  For a training graph, it is not straightforward to identify the "outputs". Using empty `tensors` is recommended.
+
+ Args:
+ tensors (list[torch.Tensor], optional): Tensors that represent the output/root of the StableHLO graph.
+
+ Returns:
+ StableHLO Module in bytecode format.
+ """
+ if tensors is None:
+ tensors = []
+ return torch_xla._XLAC._get_stablehlo(
+ tensors, torch_xla._XLAC._xla_get_default_device(), [], True)
+
+
+[docs]def wait_device_ops(devices=[]):
+ """Waits for all the async operations on the given devices to complete.
+
+ Args:
+ devices (string..., optional): The devices whose async ops need to be waited
+ for. If empty, all the local devices will be waited for.
+ """
+ torch_xla._XLAC._xla_wait_device_ops(devices=devices)
+
+
+def reduce_gradients(optimizer, groups=None, pin_layout=True):
+ """Reduces all the gradients handled by an optimizer.
+
+ Args:
+ optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance
+ containing the gradients to be reduced.
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+ pin_layout (bool, optional): whether to pin the layout when reducing gradients.
+ See `xm.all_reduce` for details.
+ """
+ count = xrt_world_size()
+ if count > 1:
+ gradients = _fetch_gradients(optimizer)
+ all_reduce(
+ REDUCE_SUM,
+ gradients,
+ scale=1.0 / count,
+ groups=groups,
+ pin_layout=pin_layout)
+
+
+[docs]def optimizer_step(optimizer,
+ barrier=False,
+ optimizer_args={},
+ groups=None,
+ pin_layout=True):
+ """Run the provided optimizer step and issue the XLA device step computation.
+
+ Args:
+ optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance whose
+ `step()` function needs to be called. The `step()` function will be called
+ with the `optimizer_args` named arguments.
+ barrier (bool, optional): Whether the XLA tensor barrier should be issued in
+ this API. If using the PyTorch XLA `ParallelLoader` or `DataParallel`
+ support, this is not necessary as the barrier will be issued by the XLA
+ data loader iterator `next()` call.
+ Default: False
+ optimizer_args (dict, optional): Named arguments dictionary for the
+ `optimizer.step()` call.
+ groups (list, optional): A list of list, representing the replica groups for
+ the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
+ defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
+ the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
+ all the replicas in it.
+ pin_layout (bool, optional): whether to pin the layout when reducing gradients.
+ See `xm.all_reduce` for details.
+
+ Returns:
+ The same value returned by the `optimizer.step()` call.
+ """
+ reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
+ loss = optimizer.step(**optimizer_args)
+ if barrier:
+ mark_step()
+ return loss
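+
+
+# Editor's illustrative sketch (not part of the original module): a single
+# training step using `optimizer_step` above. `model`, `loss_fn`, `optimizer`,
+# `data` and `target` are assumed placeholders; with a `ParallelLoader` the
+# barrier is issued by the loader, so `barrier` stays False here.
+def _example_train_step(model, loss_fn, optimizer, data, target):
+  optimizer.zero_grad()
+  loss = loss_fn(model(data), target)
+  loss.backward()
+  # Reduces gradients across replicas, then runs optimizer.step().
+  return optimizer_step(optimizer)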
+
+
+[docs]def save(data, file_or_path, master_only=True, global_master=False):
+ """Saves the input data into a file.
+
+ The saved data is transferred to PyTorch CPU device before being saved, so a
+ following `torch.load()` will load CPU data.
+ Care must be taken when working with views. Instead of saving views it's
+ recommended that you recreate them after the tensors have been loaded and
+ moved to their destination device(s).
+
+ Args:
+ data: The input data to be saved. Any nested combination of Python objects
+ (list, tuples, sets, dicts, ...).
+ file_or_path: The destination for the data saving operation. Either a file
+ path or a Python file object. If `master_only` is ``False`` the path or
+ file objects must point to different destinations as otherwise all the
+ writes from the same host will override each other.
+ master_only (bool, optional): Whether only the master device should save the
+ data. If False, the `file_or_path` argument should be a different file or
+ path for each of the ordinals taking part to the replication, otherwise
+ all the replicas on the same host will be writing to the same location.
+ Default: True
+ global_master (bool, optional): When ``master_only`` is ``True`` this flag
+ controls whether every host's master (if ``global_master`` is ``False``)
+ saves the content, or only the global master (ordinal 0).
+ Default: False
+ """
+ should_write_data = not master_only or is_master_ordinal(
+ local=not global_master)
+
+ cpu_data = _maybe_convert_to_cpu(data, convert=should_write_data)
+ if should_write_data:
+ torch.save(cpu_data, file_or_path)
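+
+
+# Editor's illustrative sketch (not part of the original module): checkpointing
+# a model with the `save` above. Tensors are moved to CPU before writing, so a
+# plain `torch.load()` can read the file later. The path is a placeholder.
+def _example_save_checkpoint(model, path='/tmp/model.pt'):
+  # With master_only=True and global_master=True only ordinal 0 writes the file.
+  save(model.state_dict(), path, master_only=True, global_master=True)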
+
+
+def _maybe_convert_to_cpu(data, convert=True):
+
+ def convert_fn(tensors):
+ torch_xla._XLAC._xla_sync_multi(
+ tensors, devices=[], wait=True, sync_xla_data=True)
+ if not convert:
+ return tensors
+ return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
+
+ def select_fn(v):
+ return type(v) == torch.Tensor and is_xla_tensor(v)
+
+ return ToXlaTensorArena(convert_fn, select_fn).transform(data)
+
+
+def send_cpu_data_to_device(datas, device, input_sharding=None):
+
+ def convert_fn(tensors):
+ devices = [str(device)] * len(tensors)
+ shardings = None
+ if input_sharding:
+ shardings = [input_sharding.xla_spec(t) for t in tensors]
+ xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
+ shardings)
+ return xtensors
+
+ def select_fn(v):
+ return type(v) == torch.Tensor and v.device.type == 'cpu'
+
+ if type(datas) is torch.Tensor:
+ datas = [datas]
+ return ToXlaTensorArena(convert_fn, select_fn).transform(datas)
+
+
+def xla_rendezvous(payload: bytes = b'',
+ ordinals: Optional[List[int]] = None,
+ tag: Optional[str] = None) -> List[bytes]:
+ """Share `payload` with all replicas in `ordinals`.
+
+ `tag` is ignored except for logging.
+
+ Uses XLA collective communication to communicate between replicas, so this
+ will sync the graph (`xm.mark_step`).
+
+ Args:
+ tag: Name of this rendezvous operation.
+ payload: Payload to share with other replicas.
+ ordinals: List of replicas participating in rendezvous.
+ Returns:
+ List of bytes from other replicas.
+ """
+ if ordinals and len(ordinals) != runtime.global_device_count():
+ raise ValueError('Only global rendezvous is supported')
+
+ if not isinstance(payload, bytes):
+ raise TypeError('`payload` must be bytes, not {}'.format(type(payload)))
+
+ # Finish all execution of previous graphs to avoid recompilation
+ mark_step()
+
+ device = xla_device()
+
+ data = torch.tensor(list(payload), dtype=torch.uint8)
+ size = torch.tensor([data.shape[0]], dtype=torch.int, device=device)
+
+ if tag:
+ logging.info(f"Joining rendezvous '{tag}'...")
+
+ sizes = all_gather(size)
+
+ max_size = torch.max(sizes)
+ mark_step()
+
+ # If all payloads are empty, return immediately to avoid more CPU transfers
+ if max_size.item() < 1:
+ return [b'' for _ in range(sizes.size()[0])]
+
+ padded_data = torch.nn.functional.pad(data, (
+ 0,
+ max_size.item() - size.item(),
+ )).to(xla_device())
+ raw_data = all_gather(padded_data)
+ data_list = torch.split(raw_data, max_size)
+
+ payloads = [d[:sz] for d, sz in zip(data_list, sizes.cpu())]
+ mark_step()
+
+ return [bytes(p.cpu().tolist()) for p in payloads]
+
+
+[docs]def rendezvous(tag, payload=b'', replicas=[]):
+ """Waits for all the mesh clients to reach the named rendezvous.
+
+ Note: PJRT does not support the XRT mesh server, so this is effectively an
+ alias to `xla_rendezvous`.
+
+ Args:
+ tag (string): The name of the rendezvous to join.
+ payload (bytes, optional): The payload to be sent to the rendezvous.
+    replicas (list, int): The replica ordinals taking part in the rendezvous.
+ Empty means all replicas in the mesh.
+ Default: []
+
+ Returns:
+ The payloads exchanged by all the other cores, with the payload of core
+ ordinal `i` at position `i` in the returned tuple.
+ """
+ return xla_rendezvous(payload, replicas or None, tag=tag)
+
+
+[docs]def do_on_ordinals(target, data=(), ordinals=(0,)):
+ """Runs a function only on a given set of ordinals.
+
+ Args:
+ target (callable): The function to be run on `ordinals`.
+ data: Any input data for the `target` function which contains tensors. All
+ the XLA tensors used by the `target` function must be passed in this
+ argument. Every other data used by the function can be captured by the
+ Python interpreter as usual.
+ Default: ()
+ ordinals (list, int): The list/set of ordinals where the `target` function
+ should run.
+ Default: (0,)
+
+ Returns:
+ In the ordinals that ran the `target` function, the function return value,
+ otherwise `None`.
+ """
+ running = get_ordinal() in ordinals
+ cpu_data = _maybe_convert_to_cpu(data, convert=running)
+ if running:
+ result = target(*cpu_data)
+ else:
+ result = None
+ rendezvous('torch_xla.core.xla_model.do_on_ordinals')
+ return result
+
+
+[docs]def mesh_reduce(tag, data, reduce_fn):
+ """Performs an out-of-graph client mesh reduction.
+
+ Args:
+ tag (string): The name of the rendezvous to join.
+ data: The data to be reduced. The `reduce_fn` callable will receive a list
+ with the copies of the same data coming from all the mesh client processes
+ (one per core).
+ reduce_fn (callable): A function which receives a list of `data`-like
+ objects and returns the reduced result.
+
+ Returns:
+ The reduced value.
+ """
+ cpu_data = _maybe_convert_to_cpu(data)
+ bio = io.BytesIO()
+ torch.save(cpu_data, bio)
+ xdata = rendezvous(tag, bio.getvalue())
+ xldata = []
+ for xd in xdata:
+ xbio = io.BytesIO(xd)
+ xldata.append(torch.load(xbio))
+ return reduce_fn(xldata) if xldata else cpu_data
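+
+
+# Editor's illustrative sketch (not part of the original module): averaging a
+# per-process Python scalar (e.g. a validation accuracy) across the mesh with
+# `mesh_reduce` above. `accuracy` is an assumed placeholder float.
+def _example_average_accuracy(accuracy):
+  return mesh_reduce('example_accuracy', accuracy,
+                     lambda values: sum(values) / len(values))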
+
+
+[docs]def set_rng_state(seed, device=None):
+ """Sets the random number generator state.
+
+ Args:
+ seed (integer): The state to be set.
+ device (string, optional): The device where the RNG state needs to be set.
+      If missing, the default device seed will be set.
+ """
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ torch_xla._XLAC._xla_set_rng_seed(seed, str(device) if device else '')
+
+
+[docs]def get_rng_state(device=None):
+ """Gets the current running random number generator state.
+
+ Args:
+ device (string, optional): The device whose RNG state needs to be retrieved.
+      If missing, the RNG state of the default device will be returned.
+
+ Returns:
+ The RNG state, as integer.
+ """
+ if device is None:
+ device = torch_xla._XLAC._xla_get_default_device()
+ return torch_xla._XLAC._xla_get_rng_seed(str(device) if device else '')
+
+
+[docs]def get_memory_info(device):
+ """Retrieves the device memory information.
+
+ Args:
+    device (string): The device whose memory information is requested.
+
+ Returns:
+ A dictionary with `kb_free` (free memory in KB) and `kb_total` (total
+ memory in KB) keys.
+ """
+ return torch_xla._XLAC._xla_memory_info(str(device))
+
+
+def optimization_barrier_(tensors):
+ """Blocks xla compiler from moving computations across this barrier. The common
+ use case would be blocking xla common-subexpression elimination pass from undoing
+ the gradient checkpointing.
+
+ Args:
+ tensors (List[torch.Tensor]): List of `torch.Tensor` to add barrier to.
+ """
+ torch_xla._XLAC._xla_optimization_barrier_(tensors)
+
+
+def broadcast_master_param(model: torch.nn.Module) -> None:
+ """
+  Broadcast the model parameters from the master process to the other processes.
+ """
+ parameters_and_buffers = list(
+ itertools.chain(model.parameters(), model.buffers()))
+ collective_broadcast(parameters_and_buffers)
+ mark_step()
+
+import itertools
+import threading
+import torch
+import torch_xla
+import torch_xla.debug.profiler as xp
+import torch_xla.utils.keyd_queue as kq
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+
+
+class PerDeviceQueue(object):
+
+ def __init__(self, device, loader_prefetch_size, device_prefetch_size):
+ self.device = device
+ self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
+ self.queue = kq.Queue(maxsize=device_prefetch_size)
+ self.close_queue_count = itertools.count()
+
+
+class PerDeviceLoader(object):
+
+ def __init__(self, loader, device):
+ self._loader = loader
+ self._device = device
+ self._mark_step_batch_count = loader.batches_per_execution - 1
+ self._batches_yielded = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.next()
+
+ def __len__(self):
+ return self._loader.per_device_samples()
+
+ def next(self):
+ if xp.get_tracer_marked_step():
+ xp.set_tracer_marked_step(False)
+ self._batches_yielded += 1
+ else:
+ if self._mark_step_batch_count <= self._batches_yielded:
+ self._batches_yielded = 0
+ xm.mark_step()
+ else:
+ self._batches_yielded += 1
+
+ item = self._loader.next_item(self._device)
+ if item is None:
+ xm.mark_step()
+ raise StopIteration
+ return item
+
+
+[docs]class ParallelLoader(object):
+ """Wraps an existing PyTorch DataLoader with background data upload.
+
+ Args:
+ loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+ wrapped.
+ devices (`torch.device`...): The list of devices where the data has to be
+ sent. The i-th sample returned by the `loader` will be sent to `devices[i
+ % len(devices)]`.
+ batchdim (int, optional): The dimension which is holding the batch size.
+ Default: 0
+ loader_prefetch_size (int, optional): The max capacity of the queue used by
+ the thread which is reading samples from the `loader`, to be processed by
+ the worker threads which upload data to the devices.
+ Default: 8
+ device_prefetch_size (int, optional): The max size of the per-device queues,
+ where the worker threads deposit tensors which have already been sent to
+ devices.
+ Default: 4
+ host_to_device_transfer_threads (int, optional): The number of threads that
+ work in parallel to transfer data from loader queue to device queue.
+ Default: 1
+ input_sharding (ShardingSpec, optional): Sharding spec to apply to
+ compatible input tensors after loading.
+ Default: None
+ """
+
+ def __init__(self,
+ loader,
+ devices,
+ batchdim=0,
+ batches_per_execution=1,
+ loader_prefetch_size=8,
+ device_prefetch_size=4,
+ host_to_device_transfer_threads=1,
+ input_sharding=None):
+ self._loader = loader
+ self._devices = [torch.device(x) for x in devices]
+ self._batchdim = batchdim
+ self._batches_per_execution = batches_per_execution
+ self._done = False
+ self._queues = dict()
+ self._input_sharding = input_sharding
+ for device in self._devices:
+ self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
+ device_prefetch_size)
+ thread = threading.Thread(target=self._loader_worker)
+ thread.daemon = True
+ thread.start()
+ for dqueue in self._queues.values():
+ for i in range(host_to_device_transfer_threads):
+ thread = threading.Thread(
+ target=self._worker,
+ args=(
+ dqueue,
+ host_to_device_transfer_threads,
+ ))
+ thread.daemon = True
+ thread.start()
+
+[docs] def per_device_loader(self, device):
+ """Retrieves the loader iterator object for the given device.
+
+ Args:
+      device (`torch.device`): The device whose loader is being requested.
+
+ Returns:
+ The loader iterator object for the `device`. This is not a
+ `torch.utils.data.DataLoader` interface, but a Python iterator which
+ returns the same tensor data structure as returned by the wrapped
+ `torch.utils.data.DataLoader`, but residing on XLA devices.
+ """
+ return PerDeviceLoader(self, torch.device(device))
+
+ def per_device_samples(self):
+ return len(self._loader) // len(self._devices)
+
+ def next_item(self, device):
+ dqueue = self._queues[device]
+ return dqueue.queue.get()
+
+ def close(self):
+ self._done = True
+ for dqueue in self._queues.values():
+ dqueue.queue.close()
+ dqueue.loader_queue.close()
+
+ @property
+ def batches_per_execution(self):
+ return self._batches_per_execution
+
+ def _loader_worker(self):
+ queues = list(self._queues.values())
+ data_iter = enumerate(self._loader)
+ batch = []
+ while not self._done:
+ try:
+ _, data = next(data_iter)
+ except StopIteration:
+ break
+ batch.append(data)
+ if len(batch) == len(self._devices):
+ for queue_no, device_batch in enumerate(batch):
+ queues[queue_no].loader_queue.put(device_batch)
+ batch = []
+ for dqueue in queues:
+ dqueue.loader_queue.close_write()
+
+ def _get_batch(self, dqueue):
+ batch = []
+ while dqueue.queue.max_size() > len(batch):
+ item = dqueue.loader_queue.get()
+ if item is None:
+ break
+ batch.append(item)
+ return batch
+
+ def _worker(self, dqueue, host_to_device_transfer_threads):
+ device = torch.device(dqueue.device)
+ while True:
+ batch = self._get_batch(dqueue)
+ if not batch:
+ break
+ batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
+ for data in batch:
+ dqueue.queue.put(data)
+ close_queue_count = next(dqueue.close_queue_count)
+ if close_queue_count == host_to_device_transfer_threads - 1:
+ dqueue.queue.close_write()
+
+
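+# Illustrative usage sketch (not part of the original module): it shows the
+# typical pattern of wrapping a CPU-side DataLoader with ParallelLoader and
+# iterating its per-device loader. `model`, `loss_fn`, `optimizer` and
+# `train_loader` are placeholders assumed to be supplied by the caller.
+def _example_parallel_loader_usage(model, loss_fn, optimizer, train_loader):
+  device = xm.xla_device()
+  para_loader = ParallelLoader(train_loader, [device])
+  for data, target in para_loader.per_device_loader(device):
+    # `data` and `target` already reside on the XLA device at this point.
+    optimizer.zero_grad()
+    loss = loss_fn(model(data), target)
+    loss.backward()
+    xm.optimizer_step(optimizer)
+
+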
+class MpDeviceLoader(object):
+ """Wraps an existing PyTorch DataLoader with background data upload.
+
+  This class should only be used with multi-processing data parallelism.
+
+ Args:
+ loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+ wrapped.
+    device (`torch.device`): The device where the data has to be sent.
+ kwargs: Named arguments for the `ParallelLoader` constructor.
+ """
+
+ def __init__(self, loader, device, **kwargs):
+ self._loader = loader
+ self._device = device
+ self._parallel_loader_kwargs = kwargs
+
+ def __iter__(self):
+ parallel_loader = ParallelLoader(self._loader, [self._device],
+ **self._parallel_loader_kwargs)
+ return parallel_loader.per_device_loader(self._device)
+
+ def __len__(self):
+ return len(self._loader)
+
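+# Illustrative usage sketch (not part of the original module): MpDeviceLoader
+# is the multi-processing counterpart, re-wrapping the loader on every
+# iteration. `train_loader` and `train_one_epoch` are placeholders assumed to
+# be supplied by the caller.
+def _example_mp_device_loader_usage(train_loader, train_one_epoch, epochs=1):
+  device = xm.xla_device()
+  mp_loader = MpDeviceLoader(train_loader, device)
+  for _ in range(epochs):
+    # Each pass over `mp_loader` yields batches already on the XLA device.
+    train_one_epoch(mp_loader)
+
+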
+import torch.multiprocessing
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt
+
+
+[docs]@xr.requires_pjrt
+def spawn(fn,
+ args=(),
+ nprocs=None,
+ join=True,
+ daemon=False,
+ start_method='spawn'):
+  """Enables multi-processing based replication.
+
+ Args:
+    fn (callable): The function to be called for each device which takes part
+      in the replication. The function will be called with a first argument being
+ the global index of the process within the replication, followed by the
+ arguments passed in `args`.
+ args (tuple): The arguments for `fn`.
+ Default: Empty tuple
+ nprocs (int): The number of processes/devices for the replication. At the
+      moment, if specified, it can be either 1 or the maximum number of devices.
+ join (bool): Whether the call should block waiting for the completion of the
+      processes which have been spawned.
+ Default: True
+ daemon (bool): Whether the processes being spawned should have the `daemon`
+ flag set (see Python multi-processing API).
+ Default: False
+ start_method (string): The Python `multiprocessing` process creation method.
+ Default: `spawn`
+
+ Returns:
+ The same object returned by the `torch.multiprocessing.spawn` API. If
+ `nprocs` is 1 the `fn` function will be called directly, and the API will
+ return None.
+ """
+ return pjrt.spawn(fn, nprocs, start_method, args)
+
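+# Illustrative usage sketch (not part of the original module): a typical
+# entry point passes a per-process function to `spawn()`. `flags` is a
+# placeholder for user configuration.
+def _example_spawn_usage(flags):
+  import torch_xla.core.xla_model as xm
+
+  def _mp_fn(index, flags):
+    device = xm.xla_device()
+    print('process {} running on {}'.format(index, device))
+
+  spawn(_mp_fn, args=(flags,))
+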
+
+[docs]class MpModelWrapper(object):
+  """Wraps a model to minimize host memory usage when the `fork` method is used.
+
+ This class should be used together with the `spawn(..., start_method='fork')`
+ API to minimize the use of host memory.
+  Instead of creating models on each multiprocessing process, and hence
+  replicating the model's initial host memory, the model is created once at
+  global scope, and then moved to each device inside the `spawn()` target
+  function.
+ Example::
+
+ WRAPPED_MODEL = xmp.MpModelWrapper(MyNetwork())
+
+ def _mp_fn(index, ...):
+ device = xm.xla_device()
+ model = WRAPPED_MODEL.to(device)
+ ...
+
+ xmp.spawn(_mp_fn, ..., start_method='fork')
+
+  This approach has two advantages. First, it uses only one copy of the memory
+  pages to host the original model weights; second, it serializes the move of
+  the wrapped model to each device, lowering the load on the system memory
+  during the process.
+ """
+
+ def __init__(self, model):
+ """Creates a new `MpModelWrapper` object.
+
+ Args:
+ model (torch.nn.Module): The model to be wrapped. Should be on PyTorch CPU
+ device (which is the default when creating new models).
+ """
+ self._model = model
+ self._lock = torch.multiprocessing.Lock()
+
+[docs] def to(self, device):
+ """Retrieves the model moved onto the specified device.
+
+ Args:
+ device (torch.device): The device where the model should be moved onto.
+ Returns:
+ The model on the specified device.
+ """
+ with self._lock:
+ self._model.to(device)
+ return self._model
+
+
+[docs]class MpSerialExecutor(object):
+ """Utility to run a function in a serialized fashion among multi-core processes.
+
+ Example::
+
+ # At global scope.
+ SERIAL_EXEC = xmp.MpSerialExecutor()
+
+ def load_dataset(path):
+ return maybe_download_and_load(path)
+
+ def _mp_fn(index, ...):
+ # Avoid all cores downloading the same data with the serial executor.
+ dataset = SERIAL_EXEC.run(lambda: load_dataset('/tmp/mnist-data'))
+ ...
+
+ xmp.spawn(_mp_fn, ...)
+ """
+
+ def __init__(self):
+ self._lock = torch.multiprocessing.Lock()
+
+[docs] def run(self, fn):
+    """Runs the provided function serialized with respect to each per-core process.
+
+ Args:
+ fn (callable): The function to run in a serialized fashion.
+ Returns:
+ The `fn` return value.
+ """
+ with self._lock:
+ return fn()
+
+import os
+import shutil
+
+import torch
+import torch_xla
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_model as xm
+
+
+class TensorReference(object):
+
+ def __init__(self, tid):
+ self.tid = tid
+
+
+def _get_tensors_folder(path):
+ return path + '.tensors'
+
+
+def _get_tensor_file(path, tid):
+ return os.path.join(path, 'tensor_{}.pt'.format(tid))
+
+
+def _rewrite_data(path, data, save_tensors):
+
+ def convert_fn(tensors):
+ torch_xla._XLAC._xla_sync_multi(
+ tensors, devices=[], wait=True, sync_xla_data=True)
+ rewritten_tensors = []
+ for i, t in enumerate(tensors):
+ if save_tensors:
+ torch.save(t.cpu(), _get_tensor_file(path, i))
+ rewritten_tensors.append(TensorReference(i))
+ return rewritten_tensors
+
+ def select_fn(v):
+ return type(v) == torch.Tensor and xm.is_xla_tensor(v)
+
+ if save_tensors:
+ if os.path.isdir(path):
+ shutil.rmtree(path)
+ os.mkdir(path)
+ return xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
+
+
+[docs]def save(data, path, master_only=True, global_master=False):
+ """Saves the input data into a file.
+
+ The saved data is transferred to PyTorch CPU device before being saved, so a
+ following `torch.load()` will load CPU data.
+ Care must be taken when working with views. Instead of saving views it's
+ recommended that you recreate them after the tensors have been loaded and
+ moved to their destination device(s).
+
+ Args:
+    data: The input data to be saved. Any nested combination of Python objects
+      (lists, tuples, sets, dicts, ...).
+    path: The destination file for the data saving operation. If `master_only`
+      is ``False``, the path must point to different destinations, as otherwise
+      all the writes from the same host will overwrite each other.
+    master_only (bool, optional): Whether only the master device should save the
+      data. If ``False``, the `path` argument should be a different path for each
+      of the ordinals taking part in the replication, otherwise all the replicas
+      on the same host will be writing to the same location.
+ Default: True
+ global_master (bool, optional): When ``master_only`` is ``True`` this flag
+ controls whether every host's master (if ``global_master`` is ``False``)
+ saves the content, or only the global master (ordinal 0).
+ Default: False
+ """
+ should_write_data = not master_only or xm.is_master_ordinal(
+ local=not global_master)
+
+ ref_data = _rewrite_data(_get_tensors_folder(path), data, should_write_data)
+ if should_write_data:
+ torch.save(ref_data, path)
+
+
+[docs]def load(path):
+ """Loads data previously saved with the `save()` API.
+
+ Args:
+ path (str): The path passed to the `save()` API.
+ Returns:
+ The loaded data.
+ """
+ ref_data = torch.load(path)
+ tensor_folder = _get_tensors_folder(path)
+
+ def convert_fn(tensors):
+ rewritten_tensors = []
+ for t in tensors:
+ rewritten_tensors.append(
+ torch.load(_get_tensor_file(tensor_folder, t.tid)))
+ return rewritten_tensors
+
+ def select_fn(v):
+ return type(v) == TensorReference
+
+ return xm.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data)
+
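+# Illustrative usage sketch (not part of the original module): saves a model
+# state_dict through the master ordinal and reloads it later. `model` and the
+# checkpoint path are placeholders assumed to be supplied by the caller; the
+# loaded tensors come back on the CPU device.
+def _example_save_load_usage(model, path='/tmp/model_checkpoint.pt'):
+  # All ordinals call save(), but only the (local) master actually writes.
+  save(model.state_dict(), path, master_only=True)
+  # Restore the checkpoint; load_state_dict() copies the CPU tensors back
+  # into the (possibly XLA-resident) model parameters.
+  state_dict = load(path)
+  model.load_state_dict(state_dict)
+
+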
+from concurrent import futures
+import contextlib
+import copy
+import os
+import shutil
+import socket
+import sys
+import tempfile
+import time
+
+
+class Cleaner(object):
+
+ def __init__(self, func):
+ self.func = func
+
+ def __del__(self):
+ self.func()
+
+
+class LazyProperty(object):
+
+ def __init__(self, gen_fn):
+ self._gen_fn = gen_fn
+
+ @property
+ def value(self):
+ if self._gen_fn is not None:
+ self._value = self._gen_fn()
+ self._gen_fn = None
+ return self._value
+
+
+class TmpFolder(object):
+
+ def __init__(self):
+ self.name = tempfile.mkdtemp()
+ self.cleaner = Cleaner(lambda: shutil.rmtree(self.name))
+
+
+[docs]class SampleGenerator(object):
+  """Iterator which returns multiple samples of the given input data.
+
+ Can be used in place of a PyTorch `DataLoader` to generate synthetic data.
+
+ Args:
+ data: The data which should be returned at each iterator step.
+ sample_count: The maximum number of `data` samples to be returned.
+ """
+
+ def __init__(self, data, sample_count):
+ self._data = data
+ self._sample_count = sample_count
+ self._count = 0
+
+ def __iter__(self):
+ return SampleGenerator(self._data, self._sample_count)
+
+ def __len__(self):
+ return self._sample_count
+
+ def __next__(self):
+ return self.next()
+
+ def next(self):
+ if self._count >= self._sample_count:
+ raise StopIteration
+ self._count += 1
+ return self._data
+
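+# Illustrative usage sketch (not part of the original module): a synthetic
+# drop-in replacement for a DataLoader that emits the same (data, target)
+# pair for a fixed number of steps, e.g. for benchmarking the input pipeline.
+# The shapes below are arbitrary placeholders.
+def _example_sample_generator(batch_size=8, steps=100):
+  import torch
+  return SampleGenerator(
+      data=(torch.zeros(batch_size, 3, 224, 224),
+            torch.zeros(batch_size, dtype=torch.int64)),
+      sample_count=steps)
+
+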
+
+class FnDataGenerator(object):
+
+ def __init__(self, func, batch_size, gen_tensor, dims=None, count=1):
+ self._func = func
+ self._batch_size = batch_size
+ self._gen_tensor = gen_tensor
+ self._dims = list(dims) if dims else [1]
+ self._count = count
+ self._emitted = 0
+
+ def __len__(self):
+ return self._count
+
+ def __iter__(self):
+ return FnDataGenerator(
+ self._func,
+ self._batch_size,
+ self._gen_tensor,
+ dims=self._dims,
+ count=self._count)
+
+ def __next__(self):
+ return self.next()
+
+ def next(self):
+ if self._emitted >= self._count:
+ raise StopIteration
+ data = self._gen_tensor(self._batch_size, *self._dims)
+ target = self._func(data)
+ self._emitted += 1
+ return data, target
+
+
+[docs]class DataWrapper(object):
+ """Utility class to wrap data structures to be sent to device."""
+
+ def __init__(self):
+ pass
+
+ def get_tensors(self):
+ """Returns the list of CPU tensors which must be sent to device."""
+ raise NotImplementedError('The method is missing an implementation')
+
+ def from_tensors(self, tensors):
+    """Builds an instance of the wrapped object given the input tensors.
+
+ The number of tensors is the same as the ones returned by the
+ `get_tensors()` API, and `tensors[i]` is the device copy of
+ `get_tensors()[i]`.
+
+ Returns:
+ The unwrapped instance of the object with tensors on device.
+ """
+ raise NotImplementedError('The method is missing an implementation')
+
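+# Illustrative sketch (not part of the original module): a minimal DataWrapper
+# subclass for a custom pair container. get_tensors() exposes the CPU tensors
+# to be transferred, and from_tensors() rebuilds the object from their device
+# copies (see for_each_instance_rewrite below).
+class _ExamplePairWrapper(DataWrapper):
+
+  def __init__(self, data, target):
+    super(_ExamplePairWrapper, self).__init__()
+    self.data = data
+    self.target = target
+
+  def get_tensors(self):
+    return [self.data, self.target]
+
+  def from_tensors(self, tensors):
+    return _ExamplePairWrapper(*tensors)
+
+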
+
+def as_list(t):
+ return t if isinstance(t, (tuple, list)) else [t]
+
+
+def getenv_as(name, type, defval=None):
+ env = os.environ.get(name, None)
+ if type == bool:
+ return defval if env is None else type(int(env))
+ return defval if env is None else type(env)
+
+
+def _for_each_instance(value, select_fn, fn, seen):
+ if id(value) in seen:
+ return
+ seen.add(id(value))
+ if select_fn(value):
+ fn(value)
+ elif isinstance(value, dict):
+ for k, v in value.items():
+ _for_each_instance(k, select_fn, fn, seen)
+ _for_each_instance(v, select_fn, fn, seen)
+ elif isinstance(value, (list, tuple, set)):
+ for x in value:
+ _for_each_instance(x, select_fn, fn, seen)
+ elif isinstance(value, DataWrapper):
+ for x in value.get_tensors():
+ _for_each_instance(x, select_fn, fn, seen)
+ elif hasattr(value, '__dict__'):
+ for k in value.__dict__.keys():
+ _for_each_instance(value.__dict__[k], select_fn, fn, seen)
+
+
+def for_each_instance(value, select_fn, fn):
+ seen = set()
+ _for_each_instance(value, select_fn, fn, seen)
+
+
+def _for_each_instance_rewrite(value, select_fn, fn, rwmap):
+ rvalue = rwmap.get(id(value), None)
+ if rvalue is not None:
+ return rvalue
+ result = value
+ if select_fn(value):
+ result = fn(value)
+ rwmap[id(value)] = result
+ elif isinstance(value, dict):
+ result = dict()
+ rwmap[id(value)] = result
+ for k, v in value.items():
+ k = _for_each_instance_rewrite(k, select_fn, fn, rwmap)
+ result[k] = _for_each_instance_rewrite(v, select_fn, fn, rwmap)
+ elif isinstance(value, set):
+ result = set()
+ rwmap[id(value)] = result
+ for x in value:
+ result.add(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
+ elif isinstance(value, (list, tuple)):
+ # We transform tuples to lists here, as we need to set the object mapping
+ # before calling into the recursion. This code might break if user code
+ # expects a tuple.
+ result = list()
+ rwmap[id(value)] = result
+ for x in value:
+ result.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
+ elif isinstance(value, DataWrapper):
+ new_tensors = []
+ for x in value.get_tensors():
+ new_tensors.append(_for_each_instance_rewrite(x, select_fn, fn, rwmap))
+ result = value.from_tensors(new_tensors)
+ rwmap[id(value)] = result
+ elif hasattr(value, '__dict__'):
+ result = copy.copy(value)
+ rwmap[id(value)] = result
+ for k in result.__dict__.keys():
+ v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap)
+ result.__dict__[k] = v
+ else:
+ rwmap[id(value)] = result
+ return result
+
+
+def for_each_instance_rewrite(value, select_fn, fn):
+ rwmap = dict()
+ return _for_each_instance_rewrite(value, select_fn, fn, rwmap)
+
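+# Illustrative usage sketch (not part of the original module): rewrites every
+# floating point tensor inside an arbitrarily nested structure, here casting
+# it to bfloat16. The structure and dtype are placeholders.
+def _example_rewrite_to_bfloat16(nested_value):
+  import torch
+  return for_each_instance_rewrite(
+      nested_value,
+      lambda v: isinstance(v, torch.Tensor),
+      lambda t: t.to(torch.bfloat16) if t.is_floating_point() else t)
+
+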
+
+def shape(inputs):
+ cshape = []
+ if isinstance(inputs, (list, tuple)):
+ lshape = None
+ for input in inputs:
+ ishape = shape(input)
+ if lshape is None:
+ lshape = ishape
+ else:
+ assert lshape == ishape
+ cshape.extend([len(inputs)] + (lshape or []))
+ return cshape
+
+
+def flatten_nested_tuple(inputs):
+ flat = []
+ if isinstance(inputs, (list, tuple)):
+ for input in inputs:
+ flat.extend(flatten_nested_tuple(input))
+ else:
+ flat.append(inputs)
+ return tuple(flat)
+
+
+def list_copy_append(ilist, item):
+ ilist_copy = list(ilist)
+ ilist_copy.append(item)
+ return ilist_copy
+
+
+def null_print(*args, **kwargs):
+ return
+
+
+def eprint(*args, **kwargs):
+ print(*args, file=sys.stderr, **kwargs)
+
+
+def get_print_fn(debug=None):
+ if debug is None:
+ debug = int(os.environ.get('DEBUG', '0'))
+ return eprint if debug else null_print
+
+
+def timed(fn, msg='', printfn=eprint):
+ if printfn is None:
+ printfn = get_print_fn()
+ s = time.time()
+ result = fn()
+ printfn('{}{:.3f}ms'.format(msg, 1000.0 * (time.time() - s)))
+ return result
+
+
+def get_free_tcp_ports(n=1):
+ ports = []
+ for _ in range(0, n):
+ with contextlib.closing(socket.socket(socket.AF_INET,
+ socket.SOCK_STREAM)) as s:
+ s.bind(('', 0))
+ ports.append(s.getsockname()[1])
+ return ports
+
+
+def parallel_work(num_workers, fn, *args):
+ """Executes fn in parallel threads with args and returns result list.
+
+ Args:
+ num_workers: number of workers in thread pool to execute work.
+ fn: python function for each thread to execute.
+ *args: arguments used to call executor.map with.
+
+ Raises:
+ Exception: re-raises any exceptions that may have been raised by workers.
+ """
+ with futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ results = executor.map(fn, *args)
+ return [res for res in results] # Iterating to re-raise any exceptions
+
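+# Illustrative usage sketch (not part of the original module): runs a simple
+# function over two argument lists with a small thread pool; any worker
+# exception is re-raised when the results are materialized.
+def _example_parallel_work():
+  return parallel_work(4, lambda a, b: a + b, [1, 2, 3], [10, 20, 30])  # [11, 22, 33]
+
+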
+
+class TimedScope(object):
+
+ def __init__(self, msg='', printfn=eprint):
+ if printfn is None:
+ printfn = get_print_fn()
+ self._msg = msg
+ self._printfn = printfn
+ self._error = None
+
+ def __enter__(self):
+ self._start = time.time()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ if self._error is None:
+ self._printfn('{}{:.3f}ms'.format(self._msg,
+ 1000.0 * (time.time() - self._start)))
+
+ def set_error(self, error):
+ self._error = error
+
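+# Illustrative usage sketch (not part of the original module): times a block
+# of work and prints the elapsed milliseconds to stderr when the scope exits.
+def _example_timed_scope():
+  with TimedScope(msg='warmup step: '):
+    time.sleep(0.01)
+
+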
+
+def check_env_flag(name, default=''):
+ flag = os.getenv(name, default)
+ return flag == 'true' or (flag.isdigit() and int(flag) > 0)
+