grid_sampler_3d, gaussian_nll_loss, dropout (#7969)
qihqi authored Sep 6, 2024
1 parent 0528c9d commit 900296a
Showing 3 changed files with 160 additions and 4 deletions.
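The three newly supported ops are reached through their standard `torch.nn.functional` entry points. A minimal sketch of the call signatures involved (plain eager PyTorch, shown for orientation only; under torch_xla2 these are the calls that route through the decompositions and jaten registrations below):

```python
import torch
import torch.nn.functional as F

# grid_sampler_3d: 5D input [N, C, D, H, W] sampled at a grid of
# normalized coordinates [N, D_out, H_out, W_out, 3] in [-1, 1].
inp = torch.randn(1, 2, 4, 4, 4)
grid = torch.rand(1, 3, 3, 3, 3) * 2 - 1
vol = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros',
                    align_corners=False)

# gaussian_nll_loss: negative log-likelihood of a Gaussian with per-element variance.
pred, target, var = torch.randn(8, 5), torch.randn(8, 5), torch.ones(8, 5)
loss = F.gaussian_nll_loss(pred, target, var)

# dropout: random, so tests can only check statistical properties (see below).
dropped = F.dropout(torch.ones(1000), p=0.5, training=True)
```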
5 changes: 2 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
@@ -179,11 +179,8 @@
"nn.functional.dropout",
"nn.functional.embedding_bag",
"nn.functional.embedding",
"nn.functional.feature_alpha_dropout",
"nn.functional.fractional_max_pool2d",
"nn.functional.fractional_max_pool3d",
"nn.functional.gaussian_nll_loss",
"nn.functional.grid_sample",
"nn.functional.group_norm",
"nn.functional.hinge_embedding_loss",
"nn.functional.instance_norm",
@@ -325,6 +322,8 @@
'rand',
'rand_like',
'uniform',
# Dropout is not deterministic https://pytorch.org/docs/stable/generated/torch.nn.functional.feature_alpha_dropout.html
'nn.functional.feature_alpha_dropout',
}

atol_dict = {"matrix_exp": (3e-2, 1e-4)}
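Random ops such as the dropout family cannot be compared elementwise against a CPU reference, which is why `nn.functional.feature_alpha_dropout` moves into the skip set above. When such ops do need coverage, a property-based check is the usual fallback; a minimal sketch (hypothetical helper, not part of this test suite, using plain `F.dropout` for simplicity):

```python
import torch
import torch.nn.functional as F

def check_dropout_properties(p=0.5, n=100_000, tol=0.02):
    # Exact outputs are not reproducible across backends, but the drop
    # fraction and the 1 / (1 - p) survivor scaling are stable properties.
    x = torch.ones(n)
    y = F.dropout(x, p=p, training=True)
    drop_frac = (y == 0).float().mean().item()
    assert abs(drop_frac - p) < tol, drop_frac
    kept = y[y != 0]
    assert torch.allclose(kept, torch.full_like(kept, 1.0 / (1.0 - p)))

check_dropout_properties()
```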
154 changes: 154 additions & 0 deletions experimental/torch_xla2/torch_xla2/decompositions.py
@@ -7,6 +7,7 @@
Can also contain decompositions of a torch op in terms of other torch ops.
"""

import functools
from typing import Any, Callable, List, Tuple

import torch
@@ -122,12 +123,165 @@ def bernoulli_float(self, p=0.5):
_try_register(aten.bernoulli_.float, bernoulli_float)
_try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_)



def _sum_tensors(ts) -> Tensor:
return functools.reduce(torch.add, ts)


@register_decomposition(aten.grid_sampler_3d)
def _grid_sampler_3d(
a: torch.Tensor,
grid: torch.Tensor,
interpolation_mode: int = 0,
padding_mode: int = 0,
align_corners: bool = False,
) -> Tensor:
"""References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
The above implements the 2D case.
"""
_expand_grid = False
torch._check(
interpolation_mode in (0, 1),
lambda: f"Invalid interpolation mode {interpolation_mode}",
)
torch._check(
padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
)

# a is 5D: [B, C, D, H, W]

def unnormalize(coords: Tensor, size: int) -> Tensor:
# Rescale coordinates from [-1, 1] to:
# [0, size - 1] if align_corners is True
# [-.5, size -.5] if align_corners is False
mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
ofs = size * 0.5 - 0.5
return coords * mul + ofs

# Reflects coordinates until they fall between low and high (inclusive).
# The bounds are passed as twice their value so that half-integer values
# can be represented as ints.
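# Worked example for intuition: with align_corners=False and size=4 the
# caller passes twice_low=-1, twice_high=7, so coords_min=-0.5 and
# coords_span=4.0, and out-of-range coordinates fold back into [-0.5, 3.5];
# e.g. 4.5 reflects to 2.5 (coords2=5.0, extra=1.0, flips=1 -> 4.0 - 0.5 - 1.0).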
def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
if twice_low == twice_high:
return torch.zeros_like(coords)
coords_min = twice_low / 2
coords_span = (twice_high - twice_low) / 2
coords2 = (coords - coords_min).abs()
extra = torch.fmod(coords2, coords_span)
flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
return torch.where(
flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
)

def compute_coordinates(coords: Tensor, size: int) -> Tensor:
if padding_mode == 0: # Zero
return coords
elif padding_mode == 1: # Borders
return torch.clamp(coords, 0, size - 1)
else: # padding_mode == 2, Reflection
if align_corners:
coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
else:
coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
return torch.clamp(coords_reflected, 0, size - 1)

def compute_source_index(coords: Tensor, size: int) -> Tensor:
coords_un = unnormalize(coords, size)
return compute_coordinates(coords_un, size)

N, C, iD, iH, iW = a.shape
_, oD, oH, oW, three = grid.shape
assert three == 3, 'Last dim of grid must be 3. got {}'.format(three)


def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
xcheck = torch.logical_and(0 <= xs, xs < iW)
ycheck = torch.logical_and(0 <= ys, ys < iH)
zcheck = torch.logical_and(0 <= zs, zs < iD)
return torch.logical_and(
xcheck, torch.logical_and(ycheck, zcheck)
)

N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)

def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
cond = in_bounds_cond(xs, ys, zs)
# To clip to inside valid coordinates, we map the coordinates
# to (x, y) = (0, 0) and also set the weight to 0
# We also change the shape of the tensor to the appropriate one for
# broadcasting with N_idx, C_idx for the purposes of advanced indexing
c = C if _expand_grid else 1
return tuple(
torch.where(cond, t, 0).view(N, c, oD, oH, oW)
for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), zs.to(dtype=torch.int64), ws)
)

def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
# Perform clipping, index into input tensor and multiply by weight
idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_

x = grid[..., 0]
y = grid[..., 1]
d = grid[..., 2]

if interpolation_mode == 0: # Bilinear
ix = compute_source_index(x, iW)
iy = compute_source_index(y, iH)
id_ = compute_source_index(d, iD)

ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1

w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)

return _sum_tensors(
get_summand(ix, iy, id_, w)
for (ix, iy, id_, w) in (
(ix_nwf, iy_nwf, id_nwf, w_nwf),
(ix_nef, iy_nef, id_nef, w_nef),
(ix_swf, iy_swf, id_swf, w_swf),
(ix_sef, iy_sef, id_sef, w_sef),
(ix_nwb, iy_nwb, id_nwb, w_nwb),
(ix_neb, iy_neb, id_neb, w_neb),
(ix_swb, iy_swb, id_swb, w_swb),
(ix_seb, iy_seb, id_seb, w_seb),
)
)
else:  # interpolation_mode == 1: Nearest
ix = compute_source_index(x, iW)
iy = compute_source_index(y, iH)
iz = compute_source_index(d, iD)

ix_nearest = ix.round()
iy_nearest = iy.round()
iz_nearest = iz.round()

return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)

EXTRA_DECOMP = decomp.get_decompositions([
torch.ops.aten.upsample_nearest2d,
torch.ops.aten._native_batch_norm_legit.no_stats,
torch.ops.aten._adaptive_avg_pool2d,
torch.ops.aten._adaptive_avg_pool3d,
torch.ops.aten.grid_sampler_2d,
torch.ops.aten.grid_sampler_3d,
torch.ops.aten.native_dropout,
torch.ops.aten.reflection_pad1d,
torch.ops.aten.reflection_pad2d,
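A direct way to sanity-check the decomposition above is to run it against the eager `aten.grid_sampler_3d` kernel across the supported mode/padding combinations. A minimal sketch (assumes `_grid_sampler_3d` is importable from `torch_xla2.decompositions` as defined in this diff):

```python
import torch
from torch_xla2.decompositions import _grid_sampler_3d  # path per this diff

inp = torch.randn(2, 3, 5, 6, 7, dtype=torch.float64)
grid = torch.rand(2, 4, 4, 4, 3, dtype=torch.float64) * 2 - 1

for mode in (0, 1):            # 0 = bilinear, 1 = nearest
    for padding in (0, 1, 2):  # 0 = zeros, 1 = border, 2 = reflection
        for align_corners in (False, True):
            ref = torch.ops.aten.grid_sampler_3d(inp, grid, mode, padding, align_corners)
            got = _grid_sampler_3d(inp, grid, mode, padding, align_corners)
            torch.testing.assert_close(got, ref)
```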
5 changes: 4 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -11,6 +11,7 @@
import torch.distributed._functional_collectives
from torch_xla2.ops import ops_registry
from torch_xla2.ops import op_base, mappings
from torch_xla2 import interop

# Keys are OpOverload, value is a callable that takes
# XLATensor2
@@ -36,6 +37,7 @@
torch.ops.aten.normal_: torch.ops.aten.normal,
torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
torch.ops.aten.clamp_: torch.ops.aten.clamp,
}
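The table entry above maps the in-place `aten.clamp_` onto the functional `aten.clamp`, so only the functional lowering needs to exist. The general pattern that mapping relies on looks roughly like this (illustrative only, not the registry's actual dispatch code):

```python
import torch

def run_inplace_as_functional(functional_op, tensor, *args, **kwargs):
    # Compute the functional result, then write it back into the original
    # tensor to emulate the in-place mutation.
    tensor.copy_(functional_op(tensor, *args, **kwargs))
    return tensor

x = torch.tensor([-1.5, 0.3, 2.7])
run_inplace_as_functional(torch.ops.aten.clamp, x, 0.0, 1.0)  # emulates x.clamp_(0, 1)
print(x)  # tensor([0.0000, 0.3000, 1.0000])
```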


@@ -1081,7 +1083,7 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None):

# (Optional) dtype conversion
if dtype is not None:
result = result.astype(dtype)
result = result.astype(mappings.t2j_dtype(dtype))

return result

@@ -1674,6 +1676,7 @@ def _aten_atan2(input, other):

# aten.bitwise_and
@op(torch.ops.aten.bitwise_and)
@op(torch.ops.aten.__and__)
def _aten_bitwise_and(self, other):
return self & other

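With `aten.__and__` registered alongside `aten.bitwise_and`, both the functional call and the `&` operator hit the same lowering. A small illustration in plain PyTorch:

```python
import torch

a = torch.tensor([0b1100, 0b1010], dtype=torch.int32)
b = torch.tensor([0b1010, 0b0110], dtype=torch.int32)

print(torch.bitwise_and(a, b))  # tensor([8, 2], dtype=torch.int32)
print(a & b)                    # same result via the __and__ overload
```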
