diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3f58480f9dd..134b65c0cbb 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -179,11 +179,8 @@ "nn.functional.dropout", "nn.functional.embedding_bag", "nn.functional.embedding", - "nn.functional.feature_alpha_dropout", "nn.functional.fractional_max_pool2d", "nn.functional.fractional_max_pool3d", - "nn.functional.gaussian_nll_loss", - "nn.functional.grid_sample", "nn.functional.group_norm", "nn.functional.hinge_embedding_loss", "nn.functional.instance_norm", @@ -325,6 +322,8 @@ 'rand', 'rand_like', 'uniform', + # Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html + 'nn.functional.feature_alpha_dropout', } atol_dict = {"matrix_exp": (3e-2, 1e-4)} diff --git a/experimental/torch_xla2/torch_xla2/decompositions.py b/experimental/torch_xla2/torch_xla2/decompositions.py index e0782505009..4ef7537ea13 100644 --- a/experimental/torch_xla2/torch_xla2/decompositions.py +++ b/experimental/torch_xla2/torch_xla2/decompositions.py @@ -7,6 +7,7 @@ Can also contain decompositions of a torch op in terms of other torch ops. """ +import functools from typing import Any, Callable, List, Tuple import torch @@ -122,12 +123,165 @@ def bernoulli_float(self, p=0.5): _try_register(aten.bernoulli_.float, bernoulli_float) _try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_) + + +def _sum_tensors(ts) -> Tensor: + return functools.reduce(torch.add, ts) + + +@register_decomposition(aten.grid_sampler_3d) +def _grid_sampler_3d( + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> Tensor: + """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 + + The above implement the 2d case. 
+ """ + _expand_grid = False + torch._check( + interpolation_mode in (0, 1), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" + ) + + # a is 5D: [B, C, D, H, W] + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. + def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iD, iH, iW = a.shape + _, oD, oH, oW, three = grid.shape + assert three == 3, 'Last dim of grid must be 3. 
got {}'.format(three) + + + def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor: + xcheck = torch.logical_and(0 <= xs, xs < iW) + ycheck = torch.logical_and(0 <= ys, ys < iH) + zcheck = torch.logical_and(0 <= zs, zs < iD) + return torch.logical_and( + xcheck, torch.logical_and(ycheck, zcheck) + ) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1) + + def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): + cond = in_bounds_cond(xs, ys, zs) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oD, oH, oW) + for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), zs.to(dtype=torch.int64), ws) + ) + + def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) + return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + d = grid[..., 2] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + id_ = compute_source_index(d, iD) + + ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor() + ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf + ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf + ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf + ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1 + ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1 + ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1 + ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1 + + w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_) + w_nef = (ix - ix_swb) * (iy_swb - iy) * 
(id_swb- id_) + w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_) + w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_) + w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef) + w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf) + w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef) + w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf) + + return _sum_tensors( + get_summand(ix, iy, id_, w) + for (ix, iy, id_, w) in ( + (ix_nwf, iy_nwf, id_nwf, w_nwf), + (ix_nef, iy_nef, id_nef, w_nef), + (ix_swf, iy_swf, id_swf, w_swf), + (ix_sef, iy_sef, id_sef, w_sef), + (ix_nwb, iy_nwb, id_nwb, w_nwb), + (ix_neb, iy_neb, id_neb, w_neb), + (ix_swb, iy_swb, id_swb, w_swb), + (ix_seb, iy_seb, id_seb, w_seb), + ) + ) + else: #interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + iz = compute_source_index(d, iD) + + ix_nearest = ix.round() + iy_nearest = iy.round() + iz_nearest = iz.round() + + return get_summand(ix_nearest, iy_nearest, iz_nearest, 1) + EXTRA_DECOMP = decomp.get_decompositions([ torch.ops.aten.upsample_nearest2d, torch.ops.aten._native_batch_norm_legit.no_stats, torch.ops.aten._adaptive_avg_pool2d, torch.ops.aten._adaptive_avg_pool3d, torch.ops.aten.grid_sampler_2d, + torch.ops.aten.grid_sampler_3d, torch.ops.aten.native_dropout, torch.ops.aten.reflection_pad1d, torch.ops.aten.reflection_pad2d, diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 803f46e46e7..97e8e6daeb5 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -11,6 +11,7 @@ import torch.distributed._functional_collectives from torch_xla2.ops import ops_registry from torch_xla2.ops import op_base, mappings +from torch_xla2 import interop # Keys are OpOverload, value is a callable that takes # XLATensor2 @@ -36,6 +37,7 @@ torch.ops.aten.normal_: torch.ops.aten.normal, torch.ops.aten.squeeze_: 
# aten.bitwise_and / aten.__and__
@op(torch.ops.aten.bitwise_and)
@op(torch.ops.aten.__and__)
def _aten_bitwise_and(self, other):
    """Return ``self & other``.

    Registered for both ``aten.bitwise_and`` and ``aten.__and__`` so that
    either op resolves to this implementation.
    """
    conjunction = self & other
    return conjunction